From 920fb6d0e1a7e3f2d48ee09bf4aeb3775264c0a8 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 19 Aug 2023 16:54:08 +0800 Subject: [PATCH] fix: embedding price config (#918) --- .../model_providers/models/embedding/base.py | 10 ++++----- .../models/embedding/minimax_embedding.py | 6 ------ .../model_providers/rules/azure_openai.json | 4 ++-- api/core/model_providers/rules/minimax.json | 21 ++++++++++++++++++- api/core/model_providers/rules/spark.json | 16 +++++++++++++- 5 files changed, 41 insertions(+), 16 deletions(-) diff --git a/api/core/model_providers/models/embedding/base.py b/api/core/model_providers/models/embedding/base.py index 92cfd02f32..1cbe56a9f3 100644 --- a/api/core/model_providers/models/embedding/base.py +++ b/api/core/model_providers/models/embedding/base.py @@ -32,7 +32,6 @@ class BaseEmbedding(BaseProviderModel): def price_config(self) -> dict: def get_or_default(): default_price_config = { - 'prompt': decimal.Decimal('0'), 'completion': decimal.Decimal('0'), 'unit': decimal.Decimal('0'), 'currency': 'USD' @@ -40,7 +39,6 @@ class BaseEmbedding(BaseProviderModel): rules = self.model_provider.get_rules() price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config price_config = { - 'prompt': decimal.Decimal(price_config['prompt']), 'completion': decimal.Decimal(price_config['completion']), 'unit': decimal.Decimal(price_config['unit']), 'currency': price_config['currency'] @@ -59,8 +57,8 @@ class BaseEmbedding(BaseProviderModel): :param tokens: :return: decimal.Decimal('0.0000001') """ - unit_price = self._price_config['completion'] - unit = self._price_config['unit'] + unit_price = self.price_config['completion'] + unit = self.price_config['unit'] total_price = tokens * unit_price * unit total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") @@ -73,7 +71,7 @@ class BaseEmbedding(BaseProviderModel): :return: decimal.Decimal('0.0001') """ - unit_price = self._price_config['completion'] + unit_price = self.price_config['completion'] unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP) logger.debug(f'unit_price:{unit_price}') return unit_price @@ -96,7 +94,7 @@ class BaseEmbedding(BaseProviderModel): :return: get from price config, default 'USD' """ - currency = self._price_config['currency'] + currency = self.price_config['currency'] return currency @abstractmethod diff --git a/api/core/model_providers/models/embedding/minimax_embedding.py b/api/core/model_providers/models/embedding/minimax_embedding.py index 185c66ab76..690ca9946f 100644 --- a/api/core/model_providers/models/embedding/minimax_embedding.py +++ b/api/core/model_providers/models/embedding/minimax_embedding.py @@ -1,6 +1,3 @@ -import decimal -import logging - from langchain.embeddings import MiniMaxEmbeddings from core.model_providers.error import LLMBadRequestError @@ -22,9 +19,6 @@ class MinimaxEmbedding(BaseEmbedding): super().__init__(model_provider, client, name) - def get_currency(self): - return 'RMB' - def handle_exceptions(self, ex: Exception) -> Exception: if isinstance(ex, ValueError): return LLMBadRequestError(f"Minimax: {str(ex)}") diff --git a/api/core/model_providers/rules/azure_openai.json b/api/core/model_providers/rules/azure_openai.json index dfb354d5a4..fe4dc10c56 100644 --- a/api/core/model_providers/rules/azure_openai.json +++ b/api/core/model_providers/rules/azure_openai.json @@ -18,8 +18,8 @@ "currency": "USD" }, "gpt-35-turbo": { - "prompt": "0.0015", - "completion": "0.002", + "prompt": "0.002", + "completion": "0.0015", "unit": "0.001", "currency": "USD" }, diff --git a/api/core/model_providers/rules/minimax.json b/api/core/model_providers/rules/minimax.json index e19b885a25..765d6712e1 100644 --- a/api/core/model_providers/rules/minimax.json +++ b/api/core/model_providers/rules/minimax.json @@ -9,5 +9,24 @@ ], "quota_unit": "tokens" }, - "model_flexibility": "fixed" + "model_flexibility": "fixed", + "price_config": { + "abab5.5-chat": { + "prompt": "0.015", + "completion": "0.015", + "unit": "0.001", + "currency": "RMB" + }, + "abab5-chat": { + "prompt": "0.015", + "completion": "0.015", + "unit": "0.001", + "currency": "RMB" + }, + "embo-01": { + "completion": "0", + "unit": "0.0001", + "currency": "RMB" + } + } } \ No newline at end of file diff --git a/api/core/model_providers/rules/spark.json b/api/core/model_providers/rules/spark.json index e19b885a25..a3a01ae4a5 100644 --- a/api/core/model_providers/rules/spark.json +++ b/api/core/model_providers/rules/spark.json @@ -9,5 +9,19 @@ ], "quota_unit": "tokens" }, - "model_flexibility": "fixed" + "model_flexibility": "fixed", + "price_config": { + "spark": { + "prompt": "0.18", + "completion": "0.18", + "unit": "0.0001", + "currency": "RMB" + }, + "spark-v2": { + "prompt": "0.36", + "completion": "0.36", + "unit": "0.0001", + "currency": "RMB" + } + } } \ No newline at end of file