diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index c5b726237f..049c9d0622 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -236,16 +236,6 @@ class AIModel(ABC): :param credentials: model credentials :return: model schema """ - if 'schema' in credentials: - schema_dict = json.loads(credentials['schema']) - - try: - model_instance = AIModelEntity.parse_obj(schema_dict) - return model_instance - except ValidationError as e: - logging.exception(f"Invalid model schema for {model}") - return self._get_customizable_model_schema(model, credentials) - return self._get_customizable_model_schema(model, credentials) def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 4a845ac44f..e9348fa114 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -1,7 +1,7 @@ from typing import Generator, List, Optional, Union, cast from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage -from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType +from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ @@ -156,9 +156,9 @@ class LocalAILarguageModel(LargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: completion_model = None if credentials['completion_type'] == 'chat_completion': - completion_model = LLMMode.CHAT + completion_model = LLMMode.CHAT.value elif credentials['completion_type'] == 'completion': - completion_model = LLMMode.COMPLETION + completion_model = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") @@ -202,7 +202,7 @@ class LocalAILarguageModel(LargeLanguageModel): ), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ 'mode': completion_model } if completion_model else {}, + model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {}, parameter_rules=rules ) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 6dcdaad53e..e4b78abefa 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -117,9 +117,9 @@ class _CommonOAI_API_Compat: if model_type == ModelType.LLM: if credentials['mode'] == 'chat': - entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value elif credentials['mode'] == 'completion': - entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index 58a4517e4e..609ea19b59 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -6,7 +6,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate import Open from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage -from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType +from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ InvokeAuthorizationError, InvokeBadRequestError, InvokeError @@ -198,7 +198,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ - 'mode': LLMMode.COMPLETION, + ModelPropertyKey.MODE: LLMMode.COMPLETION.value, }, parameter_rules=rules ) diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index cd750375be..54134feca9 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -8,7 +8,7 @@ from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \ PromptMessageRole, UserPromptMessage, SystemPromptMessage -from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType, ModelPropertyKey from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.replicate._common import _CommonReplicate @@ -91,7 +91,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ - 'mode': model_type.value + ModelPropertyKey.MODE: model_type.value }, parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) ) diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index ff39f2a1d7..c32d9a3d8e 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType +from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ @@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): } """ try: - XinferenceHelper.get_xinference_extra_parameter( + extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], model_uid=credentials['model_uid'] ) + if 'completion_type' not in credentials: + if 'chat' in extra_param.model_ability: + credentials['completion_type'] = 'chat' + elif 'generate' in extra_param.model_ability: + credentials['completion_type'] = 'completion' + else: + raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported') + except RuntimeError as e: raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') except KeyError as e: @@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ] completion_type = None - extra_args = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] - ) - if 'chat' in extra_args.model_ability: - completion_type = LLMMode.CHAT - elif 'generate' in extra_args.model_ability: - completion_type = LLMMode.COMPLETION + if 'completion_type' in credentials: + if credentials['completion_type'] == 'chat': + completion_type = LLMMode.CHAT.value + elif credentials['completion_type'] == 'completion': + completion_type = LLMMode.COMPLETION.value + else: + raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') else: - raise NotImplementedError(f'xinference model ability {extra_args.model_ability} is not supported') + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=credentials['server_url'], + model_uid=credentials['model_uid'] + ) + + if 'chat' in extra_args.model_ability: + completion_type = LLMMode.CHAT.value + elif 'generate' in extra_args.model_ability: + completion_type = LLMMode.COMPLETION.value + else: + raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') entity = AIModelEntity( model=model, @@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ - 'mode': completion_type, + ModelPropertyKey.MODE: completion_type, }, parameter_rules=rules )