mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 11:19:02 +08:00
feat: add LocalAI local embedding model support (#1021)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
parent
b5953039de
commit
417c19577a
@ -63,6 +63,9 @@ class ModelProviderFactory:
|
|||||||
elif provider_name == 'openllm':
|
elif provider_name == 'openllm':
|
||||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||||
return OpenLLMProvider
|
return OpenLLMProvider
|
||||||
|
elif provider_name == 'localai':
|
||||||
|
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||||
|
return LocalAIProvider
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
from langchain.embeddings import LocalAIEmbeddings
|
||||||
|
|
||||||
|
from replicate.exceptions import ModelError, ReplicateError
|
||||||
|
|
||||||
|
from core.model_providers.error import LLMBadRequestError
|
||||||
|
from core.model_providers.providers.base import BaseModelProvider
|
||||||
|
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAIEmbedding(BaseEmbedding):
|
||||||
|
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||||
|
credentials = model_provider.get_model_credentials(
|
||||||
|
model_name=name,
|
||||||
|
model_type=self.type
|
||||||
|
)
|
||||||
|
|
||||||
|
client = LocalAIEmbeddings(
|
||||||
|
model=name,
|
||||||
|
openai_api_key="1",
|
||||||
|
openai_api_base=credentials['server_url'],
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(model_provider, client, name)
|
||||||
|
|
||||||
|
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||||
|
if isinstance(ex, (ModelError, ReplicateError)):
|
||||||
|
return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
|
||||||
|
else:
|
||||||
|
return ex
|
131
api/core/model_providers/models/llm/localai_model.py
Normal file
131
api/core/model_providers/models/llm/localai_model.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List, Optional, Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.schema import LLMResult, get_buffer_string
|
||||||
|
|
||||||
|
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||||
|
LLMRateLimitError, LLMAuthorizationError
|
||||||
|
from core.model_providers.providers.base import BaseModelProvider
|
||||||
|
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||||
|
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
|
||||||
|
from core.model_providers.models.llm.base import BaseLLM
|
||||||
|
from core.model_providers.models.entity.message import PromptMessage
|
||||||
|
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAIModel(BaseLLM):
|
||||||
|
def __init__(self, model_provider: BaseModelProvider,
|
||||||
|
name: str,
|
||||||
|
model_kwargs: ModelKwargs,
|
||||||
|
streaming: bool = False,
|
||||||
|
callbacks: Callbacks = None):
|
||||||
|
credentials = model_provider.get_model_credentials(
|
||||||
|
model_name=name,
|
||||||
|
model_type=self.type
|
||||||
|
)
|
||||||
|
|
||||||
|
if credentials['completion_type'] == 'chat_completion':
|
||||||
|
self.model_mode = ModelMode.CHAT
|
||||||
|
else:
|
||||||
|
self.model_mode = ModelMode.COMPLETION
|
||||||
|
|
||||||
|
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||||
|
|
||||||
|
def _init_client(self) -> Any:
|
||||||
|
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||||
|
if self.model_mode == ModelMode.COMPLETION:
|
||||||
|
client = EnhanceOpenAI(
|
||||||
|
model_name=self.name,
|
||||||
|
streaming=self.streaming,
|
||||||
|
callbacks=self.callbacks,
|
||||||
|
request_timeout=60,
|
||||||
|
openai_api_key="1",
|
||||||
|
openai_api_base=self.credentials['server_url'] + '/v1',
|
||||||
|
**provider_model_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extra_model_kwargs = {
|
||||||
|
'top_p': provider_model_kwargs.get('top_p')
|
||||||
|
}
|
||||||
|
|
||||||
|
client = EnhanceChatOpenAI(
|
||||||
|
model_name=self.name,
|
||||||
|
temperature=provider_model_kwargs.get('temperature'),
|
||||||
|
max_tokens=provider_model_kwargs.get('max_tokens'),
|
||||||
|
model_kwargs=extra_model_kwargs,
|
||||||
|
streaming=self.streaming,
|
||||||
|
callbacks=self.callbacks,
|
||||||
|
request_timeout=60,
|
||||||
|
openai_api_key="1",
|
||||||
|
openai_api_base=self.credentials['server_url'] + '/v1'
|
||||||
|
)
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
def _run(self, messages: List[PromptMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs) -> LLMResult:
|
||||||
|
"""
|
||||||
|
run predict by prompt messages and stop words.
|
||||||
|
|
||||||
|
:param messages:
|
||||||
|
:param stop:
|
||||||
|
:param callbacks:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
prompts = self._get_prompt_from_messages(messages)
|
||||||
|
return self._client.generate([prompts], stop, callbacks)
|
||||||
|
|
||||||
|
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||||
|
"""
|
||||||
|
get num tokens of prompt messages.
|
||||||
|
|
||||||
|
:param messages:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
prompts = self._get_prompt_from_messages(messages)
|
||||||
|
if isinstance(prompts, str):
|
||||||
|
return self._client.get_num_tokens(prompts)
|
||||||
|
else:
|
||||||
|
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
|
||||||
|
|
||||||
|
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||||
|
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||||
|
if self.model_mode == ModelMode.COMPLETION:
|
||||||
|
for k, v in provider_model_kwargs.items():
|
||||||
|
if hasattr(self.client, k):
|
||||||
|
setattr(self.client, k, v)
|
||||||
|
else:
|
||||||
|
extra_model_kwargs = {
|
||||||
|
'top_p': provider_model_kwargs.get('top_p')
|
||||||
|
}
|
||||||
|
|
||||||
|
self.client.temperature = provider_model_kwargs.get('temperature')
|
||||||
|
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
|
||||||
|
self.client.model_kwargs = extra_model_kwargs
|
||||||
|
|
||||||
|
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||||
|
if isinstance(ex, openai.error.InvalidRequestError):
|
||||||
|
logging.warning("Invalid request to LocalAI API.")
|
||||||
|
return LLMBadRequestError(str(ex))
|
||||||
|
elif isinstance(ex, openai.error.APIConnectionError):
|
||||||
|
logging.warning("Failed to connect to LocalAI API.")
|
||||||
|
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||||
|
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||||
|
logging.warning("LocalAI service unavailable.")
|
||||||
|
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||||
|
elif isinstance(ex, openai.error.RateLimitError):
|
||||||
|
return LLMRateLimitError(str(ex))
|
||||||
|
elif isinstance(ex, openai.error.AuthenticationError):
|
||||||
|
return LLMAuthorizationError(str(ex))
|
||||||
|
elif isinstance(ex, openai.error.OpenAIError):
|
||||||
|
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||||
|
else:
|
||||||
|
return ex
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def support_streaming(cls):
|
||||||
|
return True
|
164
api/core/model_providers/providers/localai_provider.py
Normal file
164
api/core/model_providers/providers/localai_provider.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
import json
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from langchain.embeddings import LocalAIEmbeddings
|
||||||
|
from langchain.schema import HumanMessage
|
||||||
|
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
|
||||||
|
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
|
||||||
|
from core.model_providers.models.llm.localai_model import LocalAIModel
|
||||||
|
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||||
|
|
||||||
|
from core.model_providers.models.base import BaseProviderModel
|
||||||
|
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||||
|
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
|
||||||
|
from models.provider import ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAIProvider(BaseModelProvider):
|
||||||
|
@property
|
||||||
|
def provider_name(self):
|
||||||
|
"""
|
||||||
|
Returns the name of a provider.
|
||||||
|
"""
|
||||||
|
return 'localai'
|
||||||
|
|
||||||
|
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||||
|
"""
|
||||||
|
Returns the model class.
|
||||||
|
|
||||||
|
:param model_type:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if model_type == ModelType.TEXT_GENERATION:
|
||||||
|
model_class = LocalAIModel
|
||||||
|
elif model_type == ModelType.EMBEDDINGS:
|
||||||
|
model_class = LocalAIEmbedding
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return model_class
|
||||||
|
|
||||||
|
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||||
|
"""
|
||||||
|
get model parameter rules.
|
||||||
|
|
||||||
|
:param model_name:
|
||||||
|
:param model_type:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return ModelKwargsRules(
|
||||||
|
temperature=KwargRule[float](min=0, max=2, default=0.7),
|
||||||
|
top_p=KwargRule[float](min=0, max=1, default=1),
|
||||||
|
max_tokens=KwargRule[int](min=10, max=4097, default=16),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||||
|
"""
|
||||||
|
check model credentials valid.
|
||||||
|
|
||||||
|
:param model_name:
|
||||||
|
:param model_type:
|
||||||
|
:param credentials:
|
||||||
|
"""
|
||||||
|
if 'server_url' not in credentials:
|
||||||
|
raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model_type == ModelType.EMBEDDINGS:
|
||||||
|
model = LocalAIEmbeddings(
|
||||||
|
model=model_name,
|
||||||
|
openai_api_key='1',
|
||||||
|
openai_api_base=credentials['server_url']
|
||||||
|
)
|
||||||
|
|
||||||
|
model.embed_query("ping")
|
||||||
|
else:
|
||||||
|
if ('completion_type' not in credentials
|
||||||
|
or credentials['completion_type'] not in ['completion', 'chat_completion']):
|
||||||
|
raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')
|
||||||
|
|
||||||
|
if credentials['completion_type'] == 'chat_completion':
|
||||||
|
model = EnhanceChatOpenAI(
|
||||||
|
model_name=model_name,
|
||||||
|
openai_api_key='1',
|
||||||
|
openai_api_base=credentials['server_url'] + '/v1',
|
||||||
|
max_tokens=10,
|
||||||
|
request_timeout=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
model([HumanMessage(content='ping')])
|
||||||
|
else:
|
||||||
|
model = EnhanceOpenAI(
|
||||||
|
model_name=model_name,
|
||||||
|
openai_api_key='1',
|
||||||
|
openai_api_base=credentials['server_url'] + '/v1',
|
||||||
|
max_tokens=10,
|
||||||
|
request_timeout=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
model('ping')
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||||
|
credentials: dict) -> dict:
|
||||||
|
"""
|
||||||
|
encrypt model credentials for save.
|
||||||
|
|
||||||
|
:param tenant_id:
|
||||||
|
:param model_name:
|
||||||
|
:param model_type:
|
||||||
|
:param credentials:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
get credentials for llm use.
|
||||||
|
|
||||||
|
:param model_name:
|
||||||
|
:param model_type:
|
||||||
|
:param obfuscated:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.provider.provider_type != ProviderType.CUSTOM.value:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
provider_model = self._get_provider_model(model_name, model_type)
|
||||||
|
|
||||||
|
if not provider_model.encrypted_config:
|
||||||
|
return {
|
||||||
|
'server_url': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials = json.loads(provider_model.encrypted_config)
|
||||||
|
if credentials['server_url']:
|
||||||
|
credentials['server_url'] = encrypter.decrypt_token(
|
||||||
|
self.provider.tenant_id,
|
||||||
|
credentials['server_url']
|
||||||
|
)
|
||||||
|
|
||||||
|
if obfuscated:
|
||||||
|
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
|
||||||
|
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||||
|
return {}
|
@ -10,5 +10,6 @@
|
|||||||
"replicate",
|
"replicate",
|
||||||
"huggingface_hub",
|
"huggingface_hub",
|
||||||
"xinference",
|
"xinference",
|
||||||
"openllm"
|
"openllm",
|
||||||
|
"localai"
|
||||||
]
|
]
|
7
api/core/model_providers/rules/localai.json
Normal file
7
api/core/model_providers/rules/localai.json
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"support_provider_types": [
|
||||||
|
"custom"
|
||||||
|
],
|
||||||
|
"system_config": null,
|
||||||
|
"model_flexibility": "configurable"
|
||||||
|
}
|
@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
|
|||||||
return {
|
return {
|
||||||
**super()._default_params,
|
**super()._default_params,
|
||||||
"api_type": 'openai',
|
"api_type": 'openai',
|
||||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
"api_base": self.openai_api_base if self.openai_api_base
|
||||||
|
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||||
"api_version": None,
|
"api_version": None,
|
||||||
"api_key": self.openai_api_key,
|
"api_key": self.openai_api_key,
|
||||||
"organization": self.openai_organization if self.openai_organization else None,
|
"organization": self.openai_organization if self.openai_organization else None,
|
||||||
|
35
api/core/third_party/langchain/llms/open_ai.py
vendored
35
api/core/third_party/langchain/llms/open_ai.py
vendored
@ -1,7 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from typing import Dict, Any, Mapping, Optional, Union, Tuple
|
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
|
||||||
from langchain import OpenAI
|
from langchain import OpenAI
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
|
||||||
|
from langchain.schema.output import GenerationChunk
|
||||||
from pydantic import root_validator
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
|
||||||
@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
|
|||||||
def _invocation_params(self) -> Dict[str, Any]:
|
def _invocation_params(self) -> Dict[str, Any]:
|
||||||
return {**super()._invocation_params, **{
|
return {**super()._invocation_params, **{
|
||||||
"api_type": 'openai',
|
"api_type": 'openai',
|
||||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
"api_base": self.openai_api_base if self.openai_api_base
|
||||||
|
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||||
"api_version": None,
|
"api_version": None,
|
||||||
"api_key": self.openai_api_key,
|
"api_key": self.openai_api_key,
|
||||||
"organization": self.openai_organization if self.openai_organization else None,
|
"organization": self.openai_organization if self.openai_organization else None,
|
||||||
@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
|
|||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
return {**super()._identifying_params, **{
|
return {**super()._identifying_params, **{
|
||||||
"api_type": 'openai',
|
"api_type": 'openai',
|
||||||
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
"api_base": self.openai_api_base if self.openai_api_base
|
||||||
|
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||||
"api_version": None,
|
"api_version": None,
|
||||||
"api_key": self.openai_api_key,
|
"api_key": self.openai_api_key,
|
||||||
"organization": self.openai_organization if self.openai_organization else None,
|
"organization": self.openai_organization if self.openai_organization else None,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[GenerationChunk]:
|
||||||
|
params = {**self._invocation_params, **kwargs, "stream": True}
|
||||||
|
self.get_sub_prompts(params, [prompt], stop) # this mutates params
|
||||||
|
for stream_resp in completion_with_retry(
|
||||||
|
self, prompt=prompt, run_manager=run_manager, **params
|
||||||
|
):
|
||||||
|
if 'text' in stream_resp["choices"][0]:
|
||||||
|
chunk = _stream_response_to_generation_chunk(stream_resp)
|
||||||
|
yield chunk
|
||||||
|
if run_manager:
|
||||||
|
run_manager.on_llm_new_token(
|
||||||
|
chunk.text,
|
||||||
|
verbose=self.verbose,
|
||||||
|
logprobs=chunk.generation_info["logprobs"]
|
||||||
|
if chunk.generation_info
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
@ -40,3 +40,6 @@ XINFERENCE_MODEL_UID=
|
|||||||
|
|
||||||
# OpenLLM Credentials
|
# OpenLLM Credentials
|
||||||
OPENLLM_SERVER_URL=
|
OPENLLM_SERVER_URL=
|
||||||
|
|
||||||
|
# LocalAI Credentials
|
||||||
|
LOCALAI_SERVER_URL=
|
@ -0,0 +1,61 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
|
||||||
|
from core.model_providers.models.entity.model_params import ModelType
|
||||||
|
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||||
|
from models.provider import Provider, ProviderType, ProviderModel
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_provider():
|
||||||
|
return Provider(
|
||||||
|
id='provider_id',
|
||||||
|
tenant_id='tenant_id',
|
||||||
|
provider_name='localai',
|
||||||
|
provider_type=ProviderType.CUSTOM.value,
|
||||||
|
encrypted_config='',
|
||||||
|
is_valid=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_embedding_model(mocker):
|
||||||
|
model_name = 'text-embedding-ada-002'
|
||||||
|
server_url = os.environ['LOCALAI_SERVER_URL']
|
||||||
|
model_provider = LocalAIProvider(provider=get_mock_provider())
|
||||||
|
|
||||||
|
mock_query = MagicMock()
|
||||||
|
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||||
|
provider_name='localai',
|
||||||
|
model_name=model_name,
|
||||||
|
model_type=ModelType.EMBEDDINGS.value,
|
||||||
|
encrypted_config=json.dumps({
|
||||||
|
'server_url': server_url,
|
||||||
|
}),
|
||||||
|
is_valid=True,
|
||||||
|
)
|
||||||
|
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||||
|
|
||||||
|
return LocalAIEmbedding(
|
||||||
|
model_provider=model_provider,
|
||||||
|
name=model_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||||
|
return encrypted_api_key
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||||
|
def test_embed_documents(mock_decrypt, mocker):
|
||||||
|
embedding_model = get_mock_embedding_model(mocker)
|
||||||
|
rst = embedding_model.client.embed_documents(['test', 'test1'])
|
||||||
|
assert isinstance(rst, list)
|
||||||
|
assert len(rst) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||||
|
def test_embed_query(mock_decrypt, mocker):
|
||||||
|
embedding_model = get_mock_embedding_model(mocker)
|
||||||
|
rst = embedding_model.client.embed_query('test')
|
||||||
|
assert isinstance(rst, list)
|
68
api/tests/integration_tests/models/llm/test_localai_model.py
Normal file
68
api/tests/integration_tests/models/llm/test_localai_model.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from core.model_providers.models.llm.localai_model import LocalAIModel
|
||||||
|
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||||
|
from core.model_providers.models.entity.message import PromptMessage
|
||||||
|
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
|
||||||
|
from models.provider import Provider, ProviderType, ProviderModel
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_provider(server_url):
|
||||||
|
return Provider(
|
||||||
|
id='provider_id',
|
||||||
|
tenant_id='tenant_id',
|
||||||
|
provider_name='localai',
|
||||||
|
provider_type=ProviderType.CUSTOM.value,
|
||||||
|
encrypted_config=json.dumps({}),
|
||||||
|
is_valid=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_model(model_name, mocker):
|
||||||
|
model_kwargs = ModelKwargs(
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0
|
||||||
|
)
|
||||||
|
server_url = os.environ['LOCALAI_SERVER_URL']
|
||||||
|
|
||||||
|
mock_query = MagicMock()
|
||||||
|
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||||
|
provider_name='localai',
|
||||||
|
model_name=model_name,
|
||||||
|
model_type=ModelType.TEXT_GENERATION.value,
|
||||||
|
encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
|
||||||
|
is_valid=True,
|
||||||
|
)
|
||||||
|
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||||
|
|
||||||
|
openai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
|
||||||
|
return LocalAIModel(
|
||||||
|
model_provider=openai_provider,
|
||||||
|
name=model_name,
|
||||||
|
model_kwargs=model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
|
||||||
|
return encrypted_openai_api_key
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||||
|
def test_get_num_tokens(mock_decrypt, mocker):
|
||||||
|
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
|
||||||
|
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
|
||||||
|
assert rst > 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||||
|
def test_run(mock_decrypt, mocker):
|
||||||
|
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||||
|
|
||||||
|
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
|
||||||
|
rst = openai_model.run(
|
||||||
|
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
|
||||||
|
stop=['\nHuman:'],
|
||||||
|
)
|
||||||
|
assert len(rst.content) > 0
|
116
api/tests/unit_tests/model_providers/test_localai_provider.py
Normal file
116
api/tests/unit_tests/model_providers/test_localai_provider.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
import json
|
||||||
|
|
||||||
|
from core.model_providers.models.entity.model_params import ModelType
|
||||||
|
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||||
|
from core.model_providers.providers.localai_provider import LocalAIProvider
|
||||||
|
from models.provider import ProviderType, Provider, ProviderModel
|
||||||
|
|
||||||
|
PROVIDER_NAME = 'localai'
|
||||||
|
MODEL_PROVIDER_CLASS = LocalAIProvider
|
||||||
|
VALIDATE_CREDENTIAL = {
|
||||||
|
'server_url': 'http://127.0.0.1:8080/'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||||
|
return f'encrypted_{encrypt_key}'
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||||
|
return encrypted_key.replace('encrypted_', '')
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_credentials_valid_or_raise_valid(mocker):
|
||||||
|
mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
|
||||||
|
return_value="abc")
|
||||||
|
|
||||||
|
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||||
|
model_name='username/test_model_name',
|
||||||
|
model_type=ModelType.EMBEDDINGS,
|
||||||
|
credentials=VALIDATE_CREDENTIAL.copy()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_credentials_valid_or_raise_invalid():
|
||||||
|
# raise CredentialsValidateFailedError if server_url is not in credentials
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
|
||||||
|
model_name='test_model_name',
|
||||||
|
model_type=ModelType.EMBEDDINGS,
|
||||||
|
credentials={}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||||
|
def test_encrypt_model_credentials(mock_encrypt, mocker):
|
||||||
|
server_url = 'http://127.0.0.1:8080/'
|
||||||
|
|
||||||
|
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
|
||||||
|
tenant_id='tenant_id',
|
||||||
|
model_name='test_model_name',
|
||||||
|
model_type=ModelType.EMBEDDINGS,
|
||||||
|
credentials=VALIDATE_CREDENTIAL.copy()
|
||||||
|
)
|
||||||
|
mock_encrypt.assert_called_with('tenant_id', server_url)
|
||||||
|
assert result['server_url'] == f'encrypted_{server_url}'
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||||
|
def test_get_model_credentials_custom(mock_decrypt, mocker):
|
||||||
|
provider = Provider(
|
||||||
|
id='provider_id',
|
||||||
|
tenant_id='tenant_id',
|
||||||
|
provider_name=PROVIDER_NAME,
|
||||||
|
provider_type=ProviderType.CUSTOM.value,
|
||||||
|
encrypted_config=None,
|
||||||
|
is_valid=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||||
|
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
|
||||||
|
|
||||||
|
mock_query = MagicMock()
|
||||||
|
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||||
|
encrypted_config=json.dumps(encrypted_credential)
|
||||||
|
)
|
||||||
|
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||||
|
|
||||||
|
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||||
|
result = model_provider.get_model_credentials(
|
||||||
|
model_name='test_model_name',
|
||||||
|
model_type=ModelType.EMBEDDINGS
|
||||||
|
)
|
||||||
|
assert result['server_url'] == 'http://127.0.0.1:8080/'
|
||||||
|
|
||||||
|
|
||||||
|
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||||
|
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
|
||||||
|
provider = Provider(
|
||||||
|
id='provider_id',
|
||||||
|
tenant_id='tenant_id',
|
||||||
|
provider_name=PROVIDER_NAME,
|
||||||
|
provider_type=ProviderType.CUSTOM.value,
|
||||||
|
encrypted_config=None,
|
||||||
|
is_valid=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||||
|
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
|
||||||
|
|
||||||
|
mock_query = MagicMock()
|
||||||
|
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||||
|
encrypted_config=json.dumps(encrypted_credential)
|
||||||
|
)
|
||||||
|
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||||
|
|
||||||
|
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||||
|
result = model_provider.get_model_credentials(
|
||||||
|
model_name='test_model_name',
|
||||||
|
model_type=ModelType.EMBEDDINGS,
|
||||||
|
obfuscated=True
|
||||||
|
)
|
||||||
|
middle_token = result['server_url'][6:-2]
|
||||||
|
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
|
||||||
|
assert all(char == '*' for char in middle_token)
|
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 76 KiB |
15
web/app/components/base/icons/assets/public/llm/localai.svg
Normal file
15
web/app/components/base/icons/assets/public/llm/localai.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 73 KiB |
107
web/app/components/base/icons/src/public/llm/Localai.json
Normal file
107
web/app/components/base/icons/src/public/llm/Localai.json
Normal file
File diff suppressed because one or more lines are too long
14
web/app/components/base/icons/src/public/llm/Localai.tsx
Normal file
14
web/app/components/base/icons/src/public/llm/Localai.tsx
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
// GENERATE BY script
|
||||||
|
// DON NOT EDIT IT MANUALLY
|
||||||
|
|
||||||
|
import * as React from 'react'
|
||||||
|
import data from './Localai.json'
|
||||||
|
import IconBase from '@/app/components/base/icons/IconBase'
|
||||||
|
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
|
||||||
|
|
||||||
|
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
|
||||||
|
props,
|
||||||
|
ref,
|
||||||
|
) => <IconBase {...props} ref={ref} data={data as IconData} />)
|
||||||
|
|
||||||
|
export default Icon
|
170
web/app/components/base/icons/src/public/llm/LocalaiText.json
Normal file
170
web/app/components/base/icons/src/public/llm/LocalaiText.json
Normal file
File diff suppressed because one or more lines are too long
14
web/app/components/base/icons/src/public/llm/LocalaiText.tsx
Normal file
14
web/app/components/base/icons/src/public/llm/LocalaiText.tsx
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
// GENERATE BY script
|
||||||
|
// DON NOT EDIT IT MANUALLY
|
||||||
|
|
||||||
|
import * as React from 'react'
|
||||||
|
import data from './LocalaiText.json'
|
||||||
|
import IconBase from '@/app/components/base/icons/IconBase'
|
||||||
|
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
|
||||||
|
|
||||||
|
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
|
||||||
|
props,
|
||||||
|
ref,
|
||||||
|
) => <IconBase {...props} ref={ref} data={data as IconData} />)
|
||||||
|
|
||||||
|
export default Icon
|
@ -14,6 +14,8 @@ export { default as Huggingface } from './Huggingface'
|
|||||||
export { default as IflytekSparkTextCn } from './IflytekSparkTextCn'
|
export { default as IflytekSparkTextCn } from './IflytekSparkTextCn'
|
||||||
export { default as IflytekSparkText } from './IflytekSparkText'
|
export { default as IflytekSparkText } from './IflytekSparkText'
|
||||||
export { default as IflytekSpark } from './IflytekSpark'
|
export { default as IflytekSpark } from './IflytekSpark'
|
||||||
|
export { default as LocalaiText } from './LocalaiText'
|
||||||
|
export { default as Localai } from './Localai'
|
||||||
export { default as Microsoft } from './Microsoft'
|
export { default as Microsoft } from './Microsoft'
|
||||||
export { default as OpenaiBlack } from './OpenaiBlack'
|
export { default as OpenaiBlack } from './OpenaiBlack'
|
||||||
export { default as OpenaiBlue } from './OpenaiBlue'
|
export { default as OpenaiBlue } from './OpenaiBlue'
|
||||||
|
@ -10,6 +10,7 @@ import minimax from './minimax'
|
|||||||
import chatglm from './chatglm'
|
import chatglm from './chatglm'
|
||||||
import xinference from './xinference'
|
import xinference from './xinference'
|
||||||
import openllm from './openllm'
|
import openllm from './openllm'
|
||||||
|
import localai from './localai'
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
openai,
|
openai,
|
||||||
@ -24,4 +25,5 @@ export default {
|
|||||||
chatglm,
|
chatglm,
|
||||||
xinference,
|
xinference,
|
||||||
openllm,
|
openllm,
|
||||||
|
localai,
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,176 @@
|
|||||||
|
import { ProviderEnum } from '../declarations'
|
||||||
|
import type { FormValue, ProviderConfig } from '../declarations'
|
||||||
|
import { Localai, LocalaiText } from '@/app/components/base/icons/src/public/llm'
|
||||||
|
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
selector: {
|
||||||
|
name: {
|
||||||
|
'en': 'LocalAI',
|
||||||
|
'zh-Hans': 'LocalAI',
|
||||||
|
},
|
||||||
|
icon: <Localai className='w-full h-full' />,
|
||||||
|
},
|
||||||
|
item: {
|
||||||
|
key: ProviderEnum.localai,
|
||||||
|
titleIcon: {
|
||||||
|
'en': <LocalaiText className='h-6' />,
|
||||||
|
'zh-Hans': <LocalaiText className='h-6' />,
|
||||||
|
},
|
||||||
|
disable: {
|
||||||
|
tip: {
|
||||||
|
'en': 'Only supports the ',
|
||||||
|
'zh-Hans': '仅支持',
|
||||||
|
},
|
||||||
|
link: {
|
||||||
|
href: {
|
||||||
|
'en': 'https://docs.dify.ai/getting-started/install-self-hosted',
|
||||||
|
'zh-Hans': 'https://docs.dify.ai/v/zh-hans/getting-started/install-self-hosted',
|
||||||
|
},
|
||||||
|
label: {
|
||||||
|
'en': 'community open-source version',
|
||||||
|
'zh-Hans': '社区开源版本',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
modal: {
|
||||||
|
key: ProviderEnum.localai,
|
||||||
|
title: {
|
||||||
|
'en': 'LocalAI',
|
||||||
|
'zh-Hans': 'LocalAI',
|
||||||
|
},
|
||||||
|
icon: <Localai className='h-6' />,
|
||||||
|
link: {
|
||||||
|
href: 'https://github.com/go-skynet/LocalAI',
|
||||||
|
label: {
|
||||||
|
'en': 'How to deploy LocalAI',
|
||||||
|
'zh-Hans': '如何部署 LocalAI',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
defaultValue: {
|
||||||
|
model_type: 'text-generation',
|
||||||
|
completion_type: 'completion',
|
||||||
|
},
|
||||||
|
validateKeys: (v?: FormValue) => {
|
||||||
|
if (v?.model_type === 'text-generation') {
|
||||||
|
return [
|
||||||
|
'model_type',
|
||||||
|
'model_name',
|
||||||
|
'server_url',
|
||||||
|
'completion_type',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
if (v?.model_type === 'embeddings') {
|
||||||
|
return [
|
||||||
|
'model_type',
|
||||||
|
'model_name',
|
||||||
|
'server_url',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return []
|
||||||
|
},
|
||||||
|
filterValue: (v?: FormValue) => {
|
||||||
|
let filteredKeys: string[] = []
|
||||||
|
if (v?.model_type === 'text-generation') {
|
||||||
|
filteredKeys = [
|
||||||
|
'model_type',
|
||||||
|
'model_name',
|
||||||
|
'server_url',
|
||||||
|
'completion_type',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
if (v?.model_type === 'embeddings') {
|
||||||
|
filteredKeys = [
|
||||||
|
'model_type',
|
||||||
|
'model_name',
|
||||||
|
'server_url',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return filteredKeys.reduce((prev: FormValue, next: string) => {
|
||||||
|
prev[next] = v?.[next] || ''
|
||||||
|
return prev
|
||||||
|
}, {})
|
||||||
|
},
|
||||||
|
fields: [
|
||||||
|
{
|
||||||
|
type: 'radio',
|
||||||
|
key: 'model_type',
|
||||||
|
required: true,
|
||||||
|
label: {
|
||||||
|
'en': 'Model Type',
|
||||||
|
'zh-Hans': '模型类型',
|
||||||
|
},
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
key: 'text-generation',
|
||||||
|
label: {
|
||||||
|
'en': 'Text Generation',
|
||||||
|
'zh-Hans': '文本生成',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'embeddings',
|
||||||
|
label: {
|
||||||
|
'en': 'Embeddings',
|
||||||
|
'zh-Hans': 'Embeddings',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
key: 'model_name',
|
||||||
|
required: true,
|
||||||
|
label: {
|
||||||
|
'en': 'Model Name',
|
||||||
|
'zh-Hans': '模型名称',
|
||||||
|
},
|
||||||
|
placeholder: {
|
||||||
|
'en': 'Enter your Model Name here',
|
||||||
|
'zh-Hans': '在此输入您的模型名称',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hidden: (value?: FormValue) => value?.model_type === 'embeddings',
|
||||||
|
type: 'radio',
|
||||||
|
key: 'completion_type',
|
||||||
|
required: true,
|
||||||
|
label: {
|
||||||
|
'en': 'Completion Type',
|
||||||
|
'zh-Hans': 'Completion Type',
|
||||||
|
},
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
key: 'completion',
|
||||||
|
label: {
|
||||||
|
'en': 'Completion',
|
||||||
|
'zh-Hans': 'Completion',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'chat_completion',
|
||||||
|
label: {
|
||||||
|
'en': 'Chat Completion',
|
||||||
|
'zh-Hans': 'Chat Completion',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
key: 'server_url',
|
||||||
|
required: true,
|
||||||
|
label: {
|
||||||
|
'en': 'Server url',
|
||||||
|
'zh-Hans': 'Server url',
|
||||||
|
},
|
||||||
|
placeholder: {
|
||||||
|
'en': 'Enter your Server Url, eg: https://example.com/xxx',
|
||||||
|
'zh-Hans': '在此输入您的 Server Url,如:https://example.com/xxx',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
export default config
|
@ -41,6 +41,7 @@ export enum ProviderEnum {
|
|||||||
'chatglm' = 'chatglm',
|
'chatglm' = 'chatglm',
|
||||||
'xinference' = 'xinference',
|
'xinference' = 'xinference',
|
||||||
'openllm' = 'openllm',
|
'openllm' = 'openllm',
|
||||||
|
'localai' = 'localai',
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ProviderConfigItem = {
|
export type ProviderConfigItem = {
|
||||||
|
@ -99,6 +99,7 @@ const ModelPage = () => {
|
|||||||
config.chatglm,
|
config.chatglm,
|
||||||
config.xinference,
|
config.xinference,
|
||||||
config.openllm,
|
config.openllm,
|
||||||
|
config.localai,
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import { ValidatedStatus } from '../key-validator/declarations'
|
|||||||
import { ProviderEnum } from './declarations'
|
import { ProviderEnum } from './declarations'
|
||||||
import { validateModelProvider } from '@/service/common'
|
import { validateModelProvider } from '@/service/common'
|
||||||
|
|
||||||
export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm]
|
export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm, ProviderEnum.localai]
|
||||||
|
|
||||||
export const validateModelProviderFn = async (providerName: ProviderEnum, v: any) => {
|
export const validateModelProviderFn = async (providerName: ProviderEnum, v: any) => {
|
||||||
let body, url
|
let body, url
|
||||||
|
Loading…
x
Reference in New Issue
Block a user