mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 19:05:55 +08:00
feat: replicate supports default version. (#3884)
This commit is contained in:
parent
3b5b4d628b
commit
cefe156811
@ -33,11 +33,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
|
||||
version = credentials['model_version']
|
||||
model_version = ''
|
||||
if 'model_version' in credentials:
|
||||
model_version = credentials['model_version']
|
||||
|
||||
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
|
||||
model_info = client.models.get(model)
|
||||
model_info_version = model_info.versions.get(version)
|
||||
|
||||
if model_version:
|
||||
model_info_version = model_info.versions.get(model_version)
|
||||
else:
|
||||
model_info_version = model_info.latest_version
|
||||
|
||||
inputs = {**model_parameters}
|
||||
|
||||
@ -65,29 +71,35 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
if 'replicate_api_token' not in credentials:
|
||||
raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
|
||||
|
||||
if 'model_version' not in credentials:
|
||||
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
|
||||
model_version = ''
|
||||
if 'model_version' in credentials:
|
||||
model_version = credentials['model_version']
|
||||
|
||||
if model.count("/") != 1:
|
||||
raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
|
||||
'format: {user_name}/{model_name}')
|
||||
|
||||
version = credentials['model_version']
|
||||
|
||||
try:
|
||||
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
|
||||
model_info = client.models.get(model)
|
||||
model_info_version = model_info.versions.get(version)
|
||||
|
||||
self._check_text_generation_model(model_info_version, model, version)
|
||||
if model_version:
|
||||
model_info_version = model_info.versions.get(model_version)
|
||||
else:
|
||||
model_info_version = model_info.latest_version
|
||||
|
||||
self._check_text_generation_model(model_info_version, model, model_version, model_info.description)
|
||||
except ReplicateError as e:
|
||||
raise CredentialsValidateFailedError(
|
||||
f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
|
||||
f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}")
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(str(e))
|
||||
|
||||
@staticmethod
|
||||
def _check_text_generation_model(model_info_version, model_name, version):
|
||||
def _check_text_generation_model(model_info_version, model_name, version, description):
|
||||
if 'language model' in description.lower():
|
||||
return
|
||||
|
||||
if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
|
||||
or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
|
||||
or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']:
|
||||
@ -113,11 +125,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
|
||||
@classmethod
|
||||
def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]:
|
||||
version = credentials['model_version']
|
||||
model_version = ''
|
||||
if 'model_version' in credentials:
|
||||
model_version = credentials['model_version']
|
||||
|
||||
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
|
||||
model_info = client.models.get(model)
|
||||
model_info_version = model_info.versions.get(version)
|
||||
|
||||
if model_version:
|
||||
model_info_version = model_info.versions.get(model_version)
|
||||
else:
|
||||
model_info_version = model_info.latest_version
|
||||
|
||||
parameter_rules = []
|
||||
|
||||
|
@ -35,7 +35,7 @@ model_credential_schema:
|
||||
label:
|
||||
en_US: Model Version
|
||||
type: text-input
|
||||
required: true
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型版本
|
||||
en_US: Enter your model version
|
||||
zh_Hans: 在此输入您的模型版本,默认为最新版本
|
||||
en_US: Enter your model version, default to the latest version
|
||||
|
@ -17,9 +17,16 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||
user: Optional[str] = None) -> TextEmbeddingResult:
|
||||
|
||||
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
|
||||
replicate_model_version = f'{model}:{credentials["model_version"]}'
|
||||
|
||||
text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
|
||||
if 'model_version' in credentials:
|
||||
model_version = credentials['model_version']
|
||||
else:
|
||||
model_info = client.models.get(model)
|
||||
model_version = model_info.latest_version.id
|
||||
|
||||
replicate_model_version = f'{model}:{model_version}'
|
||||
|
||||
text_input_key = self._get_text_input_key(model, model_version, client)
|
||||
|
||||
embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
|
||||
texts)
|
||||
@ -43,14 +50,18 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||
if 'replicate_api_token' not in credentials:
|
||||
raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
|
||||
|
||||
if 'model_version' not in credentials:
|
||||
raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
|
||||
|
||||
try:
|
||||
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
|
||||
replicate_model_version = f'{model}:{credentials["model_version"]}'
|
||||
|
||||
text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
|
||||
if 'model_version' in credentials:
|
||||
model_version = credentials['model_version']
|
||||
else:
|
||||
model_info = client.models.get(model)
|
||||
model_version = model_info.latest_version.id
|
||||
|
||||
replicate_model_version = f'{model}:{model_version}'
|
||||
|
||||
text_input_key = self._get_text_input_key(model, model_version, client)
|
||||
|
||||
self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
|
||||
['Hello worlds!'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user