From cefe156811eecbddf4dae39717a253d37e4431d8 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 26 Apr 2024 21:16:22 +0800 Subject: [PATCH] feat: replicate supports default version. (#3884) --- .../model_providers/replicate/llm/llm.py | 42 +++++++++++++------ .../model_providers/replicate/replicate.yaml | 6 +-- .../text_embedding/text_embedding.py | 25 +++++++---- 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index ee2de85607..f4198dbfa7 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -33,11 +33,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - version = credentials['model_version'] + model_version = '' + if 'model_version' in credentials: + model_version = credentials['model_version'] client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) model_info = client.models.get(model) - model_info_version = model_info.versions.get(version) + + if model_version: + model_info_version = model_info.versions.get(model_version) + else: + model_info_version = model_info.latest_version inputs = {**model_parameters} @@ -65,29 +71,35 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): if 'replicate_api_token' not in credentials: raise CredentialsValidateFailedError('Replicate Access Token must be provided.') - if 'model_version' not in credentials: - raise CredentialsValidateFailedError('Replicate Model Version must be provided.') + model_version = '' + if 'model_version' in credentials: + model_version = credentials['model_version'] if model.count("/") != 1: raise CredentialsValidateFailedError('Replicate Model Name must be provided, ' 'format: {user_name}/{model_name}') - version = credentials['model_version'] - try: client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) model_info = client.models.get(model) - model_info_version = model_info.versions.get(version) - self._check_text_generation_model(model_info_version, model, version) + if model_version: + model_info_version = model_info.versions.get(model_version) + else: + model_info_version = model_info.latest_version + + self._check_text_generation_model(model_info_version, model, model_version, model_info.description) except ReplicateError as e: raise CredentialsValidateFailedError( - f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}") + f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}") except Exception as e: raise CredentialsValidateFailedError(str(e)) @staticmethod - def _check_text_generation_model(model_info_version, model_name, version): + def _check_text_generation_model(model_info_version, model_name, version, description): + if 'language model' in description.lower(): + return + if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: @@ -113,11 +125,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): @classmethod def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]: - version = credentials['model_version'] + model_version = '' + if 'model_version' in credentials: + model_version = credentials['model_version'] client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) model_info = client.models.get(model) - model_info_version = model_info.versions.get(version) + + if model_version: + model_info_version = model_info.versions.get(model_version) + else: + model_info_version = model_info.latest_version parameter_rules = [] diff --git a/api/core/model_runtime/model_providers/replicate/replicate.yaml b/api/core/model_runtime/model_providers/replicate/replicate.yaml index 11f615dad1..9cad6d4f0d 100644 --- a/api/core/model_runtime/model_providers/replicate/replicate.yaml +++ b/api/core/model_runtime/model_providers/replicate/replicate.yaml @@ -35,7 +35,7 @@ model_credential_schema: label: en_US: Model Version type: text-input - required: true + required: false placeholder: - zh_Hans: 在此输入您的模型版本 - en_US: Enter your model version + zh_Hans: 在此输入您的模型版本,默认为最新版本 + en_US: Enter your model version, default to the latest version diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index a481aebc99..0e4cdbf5bc 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -17,9 +17,16 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): user: Optional[str] = None) -> TextEmbeddingResult: client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) - replicate_model_version = f'{model}:{credentials["model_version"]}' - text_input_key = self._get_text_input_key(model, credentials['model_version'], client) + if 'model_version' in credentials: + model_version = credentials['model_version'] + else: + model_info = client.models.get(model) + model_version = model_info.latest_version.id + + replicate_model_version = f'{model}:{model_version}' + + text_input_key = self._get_text_input_key(model, model_version, client) embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, texts) @@ -43,14 +50,18 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): if 'replicate_api_token' not in credentials: raise CredentialsValidateFailedError('Replicate Access Token must be provided.') - if 'model_version' not in credentials: - raise CredentialsValidateFailedError('Replicate Model Version must be provided.') - try: client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) - replicate_model_version = f'{model}:{credentials["model_version"]}' - text_input_key = self._get_text_input_key(model, credentials['model_version'], client) + if 'model_version' in credentials: + model_version = credentials['model_version'] + else: + model_info = client.models.get(model) + model_version = model_info.latest_version.id + + replicate_model_version = f'{model}:{model_version}' + + text_input_key = self._get_text_input_key(model, model_version, client) self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, ['Hello worlds!'])