diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index b312d99b1c..64b7341e9f 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -1,9 +1,23 @@ from collections.abc import Generator +from decimal import Decimal from typing import Optional, Union -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel @@ -36,8 +50,98 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) + REPETITION_PENALTY = "repetition_penalty" + TOP_K = "top_k" + features = [] - return super().get_customizable_model_schema(model, cred_with_endpoint) + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")), + ModelPropertyKey.MODE: cred_with_endpoint.get('mode'), + }, + parameter_rules=[ + ParameterRule( + 
name=DefaultParameterName.TEMPERATURE.value, + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + default=float(cred_with_endpoint.get('temperature', 0.7)), + min=0, + max=2, + precision=2 + ), + ParameterRule( + name=DefaultParameterName.TOP_P.value, + label=I18nObject(en_US="Top P"), + type=ParameterType.FLOAT, + default=float(cred_with_endpoint.get('top_p', 1)), + min=0, + max=1, + precision=2 + ), + ParameterRule( + name=TOP_K, + label=I18nObject(en_US="Top K"), + type=ParameterType.INT, + default=int(cred_with_endpoint.get('top_k', 50)), + min=-2147483647, + max=2147483647, + precision=0 + ), + ParameterRule( + name=REPETITION_PENALTY, + label=I18nObject(en_US="Repetition Penalty"), + type=ParameterType.FLOAT, + default=float(cred_with_endpoint.get('repetition_penalty', 1)), + min=-3.4, + max=3.4, + precision=1 + ), + ParameterRule( + name=DefaultParameterName.MAX_TOKENS.value, + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + default=512, + min=1, + max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)), + ), + ParameterRule( + name=DefaultParameterName.FREQUENCY_PENALTY.value, + label=I18nObject(en_US="Frequency Penalty"), + type=ParameterType.FLOAT, + default=float(cred_with_endpoint.get('frequency_penalty', 0)), + min=-2, + max=2 + ), + ParameterRule( + name=DefaultParameterName.PRESENCE_PENALTY.value, + label=I18nObject(en_US="Presence Penalty"), + type=ParameterType.FLOAT, + default=float(cred_with_endpoint.get('presence_penalty', 0)), + min=-2, + max=2 + ) + ], + pricing=PriceConfig( + input=Decimal(cred_with_endpoint.get('input_price', 0)), + output=Decimal(cred_with_endpoint.get('output_price', 0)), + unit=Decimal(cred_with_endpoint.get('unit', 0)), + currency=cred_with_endpoint.get('currency', "USD") + ), + ) + + if cred_with_endpoint['mode'] == 'chat': + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value + elif cred_with_endpoint['mode'] == 'completion': + entity.model_properties[ModelPropertyKey.MODE] = 
LLMMode.COMPLETION.value + else: + raise ValueError(f"Unknown completion type {cred_with_endpoint['mode']}") + + return entity def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: