Fix/localai (#2840)

This commit is contained in:
Yeuoly 2024-03-15 11:41:51 +08:00 committed by GitHub
parent af98954fc1
commit 742be06ea9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 7 deletions

View File

@ -1,6 +1,5 @@
from collections.abc import Generator from collections.abc import Generator
from typing import cast from typing import cast
from urllib.parse import urljoin
from httpx import Timeout from httpx import Timeout
from openai import ( from openai import (
@ -19,6 +18,7 @@ from openai import (
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion from openai.types.completion import Completion
from yarl import URL
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
@ -181,7 +181,7 @@ class LocalAILarguageModel(LargeLanguageModel):
UserPromptMessage(content='ping') UserPromptMessage(content='ping')
], model_parameters={ ], model_parameters={
'max_tokens': 10, 'max_tokens': 10,
}, stop=[]) }, stop=[], stream=False)
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
@ -227,6 +227,12 @@ class LocalAILarguageModel(LargeLanguageModel):
) )
] ]
model_properties = {
ModelPropertyKey.MODE: completion_model,
} if completion_model else {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
label=I18nObject( label=I18nObject(
@ -234,7 +240,7 @@ class LocalAILarguageModel(LargeLanguageModel):
), ),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {}, model_properties=model_properties,
parameter_rules=rules parameter_rules=rules
) )
@ -319,7 +325,7 @@ class LocalAILarguageModel(LargeLanguageModel):
client_kwargs = { client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1", "api_key": "1",
"base_url": urljoin(credentials['server_url'], 'v1'), "base_url": str(URL(credentials['server_url']) / 'v1'),
} }
return client_kwargs return client_kwargs

View File

@ -56,3 +56,12 @@ model_credential_schema:
placeholder: placeholder:
zh_Hans: 在此输入LocalAI的服务器地址如 http://192.168.1.100:8080 zh_Hans: 在此输入LocalAI的服务器地址如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080 en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
- variable: context_size
label:
zh_Hans: 上下文大小
en_US: Context size
placeholder:
zh_Hans: 输入上下文大小
en_US: Enter context size
required: false
type: text-input

View File

@ -1,11 +1,12 @@
import time import time
from json import JSONDecodeError, dumps from json import JSONDecodeError, dumps
from os.path import join
from typing import Optional from typing import Optional
from requests import post from requests import post
from yarl import URL
from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (
InvokeAuthorizationError, InvokeAuthorizationError,
@ -57,7 +58,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
} }
try: try:
response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10) response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
except Exception as e: except Exception as e:
raise InvokeConnectionError(str(e)) raise InvokeConnectionError(str(e))
@ -114,6 +115,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
num_tokens += self._get_num_tokens_by_gpt2(text) num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens return num_tokens
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
return AIModelEntity(
model=model,
label=I18nObject(zh_Hans=model, en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
features=[],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[]
)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
""" """
Validate model credentials Validate model credentials