feat: replicate supports default version. (#3884)

This commit is contained in:
Garfield Dai 2024-04-26 21:16:22 +08:00 committed by GitHub
parent 3b5b4d628b
commit cefe156811
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 22 deletions

View File

@ -33,11 +33,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]: user: Optional[str] = None) -> Union[LLMResult, Generator]:
version = credentials['model_version'] model_version = ''
if 'model_version' in credentials:
model_version = credentials['model_version']
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model) model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)
if model_version:
model_info_version = model_info.versions.get(model_version)
else:
model_info_version = model_info.latest_version
inputs = {**model_parameters} inputs = {**model_parameters}
@ -65,29 +71,35 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
if 'replicate_api_token' not in credentials: if 'replicate_api_token' not in credentials:
raise CredentialsValidateFailedError('Replicate Access Token must be provided.') raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
if 'model_version' not in credentials: model_version = ''
raise CredentialsValidateFailedError('Replicate Model Version must be provided.') if 'model_version' in credentials:
model_version = credentials['model_version']
if model.count("/") != 1: if model.count("/") != 1:
raise CredentialsValidateFailedError('Replicate Model Name must be provided, ' raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
'format: {user_name}/{model_name}') 'format: {user_name}/{model_name}')
version = credentials['model_version']
try: try:
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model) model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)
self._check_text_generation_model(model_info_version, model, version) if model_version:
model_info_version = model_info.versions.get(model_version)
else:
model_info_version = model_info.latest_version
self._check_text_generation_model(model_info_version, model, model_version, model_info.description)
except ReplicateError as e: except ReplicateError as e:
raise CredentialsValidateFailedError( raise CredentialsValidateFailedError(
f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}") f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}")
except Exception as e: except Exception as e:
raise CredentialsValidateFailedError(str(e)) raise CredentialsValidateFailedError(str(e))
@staticmethod @staticmethod
def _check_text_generation_model(model_info_version, model_name, version): def _check_text_generation_model(model_info_version, model_name, version, description):
if 'language model' in description.lower():
return
if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']:
@ -113,11 +125,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
@classmethod @classmethod
def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]: def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]:
version = credentials['model_version'] model_version = ''
if 'model_version' in credentials:
model_version = credentials['model_version']
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
model_info = client.models.get(model) model_info = client.models.get(model)
model_info_version = model_info.versions.get(version)
if model_version:
model_info_version = model_info.versions.get(model_version)
else:
model_info_version = model_info.latest_version
parameter_rules = [] parameter_rules = []

View File

@ -35,7 +35,7 @@ model_credential_schema:
label: label:
en_US: Model Version en_US: Model Version
type: text-input type: text-input
required: true required: false
placeholder: placeholder:
zh_Hans: 在此输入您的模型版本 zh_Hans: 在此输入您的模型版本,默认为最新版本
en_US: Enter your model version en_US: Enter your model version, default to the latest version

View File

@ -17,9 +17,16 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
user: Optional[str] = None) -> TextEmbeddingResult: user: Optional[str] = None) -> TextEmbeddingResult:
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
replicate_model_version = f'{model}:{credentials["model_version"]}'
text_input_key = self._get_text_input_key(model, credentials['model_version'], client) if 'model_version' in credentials:
model_version = credentials['model_version']
else:
model_info = client.models.get(model)
model_version = model_info.latest_version.id
replicate_model_version = f'{model}:{model_version}'
text_input_key = self._get_text_input_key(model, model_version, client)
embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
texts) texts)
@ -43,14 +50,18 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
if 'replicate_api_token' not in credentials: if 'replicate_api_token' not in credentials:
raise CredentialsValidateFailedError('Replicate Access Token must be provided.') raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
if 'model_version' not in credentials:
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
try: try:
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
replicate_model_version = f'{model}:{credentials["model_version"]}'
text_input_key = self._get_text_input_key(model, credentials['model_version'], client) if 'model_version' in credentials:
model_version = credentials['model_version']
else:
model_info = client.models.get(model)
model_version = model_info.latest_version.id
replicate_model_version = f'{model}:{model_version}'
text_input_key = self._get_text_input_key(model, model_version, client)
self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
['Hello worlds!']) ['Hello worlds!'])