mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 02:29:03 +08:00
fix: xinference cache (#1926)
This commit is contained in:
parent
01f9feff9f
commit
5a756ca981
@ -236,16 +236,6 @@ class AIModel(ABC):
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
if 'schema' in credentials:
|
||||
schema_dict = json.loads(credentials['schema'])
|
||||
|
||||
try:
|
||||
model_instance = AIModelEntity.parse_obj(schema_dict)
|
||||
return model_instance
|
||||
except ValidationError as e:
|
||||
logging.exception(f"Invalid model schema for {model}")
|
||||
return self._get_customizable_model_schema(model, credentials)
|
||||
|
||||
return self._get_customizable_model_schema(model, credentials)
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Generator, List, Optional, Union, cast
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
@ -156,9 +156,9 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
completion_model = None
|
||||
if credentials['completion_type'] == 'chat_completion':
|
||||
completion_model = LLMMode.CHAT
|
||||
completion_model = LLMMode.CHAT.value
|
||||
elif credentials['completion_type'] == 'completion':
|
||||
completion_model = LLMMode.COMPLETION
|
||||
completion_model = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
@ -202,7 +202,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={ 'mode': completion_model } if completion_model else {},
|
||||
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
|
@ -117,9 +117,9 @@ class _CommonOAI_API_Compat:
|
||||
|
||||
if model_type == ModelType.LLM:
|
||||
if credentials['mode'] == 'chat':
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
|
||||
elif credentials['mode'] == 'completion':
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
|
@ -6,7 +6,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate import Open
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
|
||||
@ -198,7 +198,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
'mode': LLMMode.COMPLETION,
|
||||
ModelPropertyKey.MODE: LLMMode.COMPLETION.value,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
@ -8,7 +8,7 @@ from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
|
||||
PromptMessageRole, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType, ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.replicate._common import _CommonReplicate
|
||||
@ -91,7 +91,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
'mode': model_type.value
|
||||
ModelPropertyKey.MODE: model_type.value
|
||||
},
|
||||
parameter_rules=self._get_customizable_model_parameter_rules(model, credentials)
|
||||
)
|
||||
|
@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
|
||||
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
XinferenceHelper.get_xinference_extra_parameter(
|
||||
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
)
|
||||
if 'completion_type' not in credentials:
|
||||
if 'chat' in extra_param.model_ability:
|
||||
credentials['completion_type'] = 'chat'
|
||||
elif 'generate' in extra_param.model_ability:
|
||||
credentials['completion_type'] = 'completion'
|
||||
else:
|
||||
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
|
||||
|
||||
except RuntimeError as e:
|
||||
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
|
||||
except KeyError as e:
|
||||
@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
]
|
||||
|
||||
completion_type = None
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
)
|
||||
|
||||
if 'chat' in extra_args.model_ability:
|
||||
completion_type = LLMMode.CHAT
|
||||
elif 'generate' in extra_args.model_ability:
|
||||
completion_type = LLMMode.COMPLETION
|
||||
if 'completion_type' in credentials:
|
||||
if credentials['completion_type'] == 'chat':
|
||||
completion_type = LLMMode.CHAT.value
|
||||
elif credentials['completion_type'] == 'completion':
|
||||
completion_type = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
|
||||
else:
|
||||
raise NotImplementedError(f'xinference model ability {extra_args.model_ability} is not supported')
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
)
|
||||
|
||||
if 'chat' in extra_args.model_ability:
|
||||
completion_type = LLMMode.CHAT.value
|
||||
elif 'generate' in extra_args.model_ability:
|
||||
completion_type = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
'mode': completion_type,
|
||||
ModelPropertyKey.MODE: completion_type,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user