mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 01:15:58 +08:00
Fix/localai (#2840)
This commit is contained in:
parent
af98954fc1
commit
742be06ea9
@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
from httpx import Timeout
|
from httpx import Timeout
|
||||||
from openai import (
|
from openai import (
|
||||||
@ -19,6 +18,7 @@ from openai import (
|
|||||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
from openai.types.chat.chat_completion_message import FunctionCall
|
from openai.types.chat.chat_completion_message import FunctionCall
|
||||||
from openai.types.completion import Completion
|
from openai.types.completion import Completion
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||||
@ -181,7 +181,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
UserPromptMessage(content='ping')
|
UserPromptMessage(content='ping')
|
||||||
], model_parameters={
|
], model_parameters={
|
||||||
'max_tokens': 10,
|
'max_tokens': 10,
|
||||||
}, stop=[])
|
}, stop=[], stream=False)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
|
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')
|
||||||
|
|
||||||
@ -227,6 +227,12 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
model_properties = {
|
||||||
|
ModelPropertyKey.MODE: completion_model,
|
||||||
|
} if completion_model else {}
|
||||||
|
|
||||||
|
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
|
||||||
|
|
||||||
entity = AIModelEntity(
|
entity = AIModelEntity(
|
||||||
model=model,
|
model=model,
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
@ -234,7 +240,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
),
|
),
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
|
model_properties=model_properties,
|
||||||
parameter_rules=rules
|
parameter_rules=rules
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -319,7 +325,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
client_kwargs = {
|
client_kwargs = {
|
||||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||||
"api_key": "1",
|
"api_key": "1",
|
||||||
"base_url": urljoin(credentials['server_url'], 'v1'),
|
"base_url": str(URL(credentials['server_url']) / 'v1'),
|
||||||
}
|
}
|
||||||
|
|
||||||
return client_kwargs
|
return client_kwargs
|
||||||
|
@ -56,3 +56,12 @@ model_credential_schema:
|
|||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
|
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
|
||||||
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
|
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
|
||||||
|
- variable: context_size
|
||||||
|
label:
|
||||||
|
zh_Hans: 上下文大小
|
||||||
|
en_US: Context size
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 输入上下文大小
|
||||||
|
en_US: Enter context size
|
||||||
|
required: false
|
||||||
|
type: text-input
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
from json import JSONDecodeError, dumps
|
from json import JSONDecodeError, dumps
|
||||||
from os.path import join
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from requests import post
|
from requests import post
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
InvokeAuthorizationError,
|
InvokeAuthorizationError,
|
||||||
@ -57,7 +58,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
|
response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise InvokeConnectionError(str(e))
|
raise InvokeConnectionError(str(e))
|
||||||
|
|
||||||
@ -114,6 +115,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
|||||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||||
|
"""
|
||||||
|
Get customizable model schema
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return: model schema
|
||||||
|
"""
|
||||||
|
return AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(zh_Hans=model, en_US=model),
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
features=[],
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={
|
||||||
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
|
||||||
|
ModelPropertyKey.MAX_CHUNKS: 1,
|
||||||
|
},
|
||||||
|
parameter_rules=[]
|
||||||
|
)
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Validate model credentials
|
Validate model credentials
|
||||||
|
Loading…
x
Reference in New Issue
Block a user