From fd0fc8f4fe810b697cf8f95c4b1309ca61918f51 Mon Sep 17 00:00:00 2001 From: "Krasus.Chen" Date: Sat, 19 Aug 2023 16:41:35 +0800 Subject: [PATCH] Fix/price calc (#862) --- api/core/conversation_message_task.py | 31 +++----- api/core/indexing_runner.py | 6 +- .../embedding/azure_openai_embedding.py | 19 +++-- .../model_providers/models/embedding/base.py | 74 +++++++++++++++++-- .../models/embedding/minimax_embedding.py | 3 - .../models/embedding/openai_embedding.py | 10 --- .../models/embedding/replicate_embedding.py | 7 -- .../models/llm/anthropic_model.py | 26 ------- .../models/llm/azure_openai_model.py | 49 +++--------- api/core/model_providers/models/llm/base.py | 73 ++++++++++++++++-- .../models/llm/chatglm_model.py | 3 - .../models/llm/huggingface_hub_model.py | 7 -- .../models/llm/minimax_model.py | 3 - .../models/llm/openai_model.py | 41 +--------- .../models/llm/replicate_model.py | 7 -- .../model_providers/models/llm/spark_model.py | 3 - .../models/llm/tongyi_model.py | 3 - .../models/llm/wenxin_model.py | 31 +------- api/core/model_providers/rules/anthropic.json | 16 +++- .../model_providers/rules/azure_openai.json | 45 ++++++++++- api/core/model_providers/rules/openai.json | 39 +++++++++- api/core/model_providers/rules/wenxin.json | 22 +++++- 22 files changed, 288 insertions(+), 230 deletions(-) diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 099c19be27..df06101e4d 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -140,10 +140,13 @@ class ConversationMessageTask: def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): message_tokens = llm_message.prompt_tokens answer_tokens = llm_message.completion_tokens - message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN) - answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT) - total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price) + message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) + answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) + + message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) + answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) + total_price = message_total_price + answer_total_price self.message.message = llm_message.prompt self.message.message_tokens = message_tokens @@ -206,18 +209,15 @@ class ConversationMessageTask: def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, agent_loop: AgentLoop): - agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN) - agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT) + agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN) + agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT) loop_message_tokens = agent_loop.prompt_tokens loop_answer_tokens = agent_loop.completion_tokens - loop_total_price = self.calc_total_price( - loop_message_tokens, - agent_message_unit_price, - loop_answer_tokens, - agent_answer_unit_price - ) + loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN) + loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT) + loop_total_price = loop_message_total_price + loop_answer_total_price message_agent_thought.observation = agent_loop.tool_output message_agent_thought.tool_process_data = '' # currently not support @@ -243,15 +243,6 @@ class ConversationMessageTask: db.session.add(dataset_query) - def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price): - message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'), - rounding=decimal.ROUND_HALF_UP) - answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'), - rounding=decimal.ROUND_HALF_UP) - - total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price - return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) - def end(self): self._pub_handler.pub_end() diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 21a678fef9..2f47f52705 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -278,7 +278,7 @@ class IndexingRunner: "total_segments": total_segments * 20, "tokens": total_segments * 2000, "total_price": '{:f}'.format( - text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), + text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)), "currency": embedding_model.get_currency(), "qa_preview": document_qa_list, "preview": preview_texts @@ -286,7 +286,7 @@ class IndexingRunner: return { "total_segments": total_segments, "tokens": tokens, - "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)), + "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)), "currency": embedding_model.get_currency(), "preview": preview_texts } @@ -371,7 +371,7 @@ class IndexingRunner: "total_segments": total_segments * 20, "tokens": total_segments * 2000, "total_price": '{:f}'.format( - text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), + text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)), "currency": embedding_model.get_currency(), "qa_preview": document_qa_list, "preview": preview_texts diff --git a/api/core/model_providers/models/embedding/azure_openai_embedding.py b/api/core/model_providers/models/embedding/azure_openai_embedding.py index 1b9149392b..506b32c557 100644 --- a/api/core/model_providers/models/embedding/azure_openai_embedding.py +++ b/api/core/model_providers/models/embedding/azure_openai_embedding.py @@ -31,6 +31,15 @@ class AzureOpenAIEmbedding(BaseEmbedding): ) super().__init__(model_provider, client, name) + + @property + def base_model_name(self) -> str: + """ + get base model name (not deployment) + + :return: str + """ + return self.credentials.get("base_model_name") def get_num_tokens(self, text: str) -> int: """ @@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding): # calculate the number of tokens in the encoded text return len(tokenized_text) - def get_token_price(self, tokens: int): - tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), - rounding=decimal.ROUND_HALF_UP) - - total_price = tokens_per_1k * decimal.Decimal('0.0001') - return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) - - def get_currency(self): - return 'USD' - def handle_exceptions(self, ex: Exception) -> Exception: if isinstance(ex, openai.error.InvalidRequestError): logging.warning("Invalid request to Azure OpenAI API.") diff --git a/api/core/model_providers/models/embedding/base.py b/api/core/model_providers/models/embedding/base.py index fc42d88bcd..92cfd02f32 100644 --- a/api/core/model_providers/models/embedding/base.py +++ b/api/core/model_providers/models/embedding/base.py @@ -1,5 +1,6 @@ from abc import abstractmethod from typing import Any +import decimal import tiktoken from langchain.schema.language_model import _get_token_ids_default_method @@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.entity.model_params import ModelType from core.model_providers.providers.base import BaseModelProvider - +import logging +logger = logging.getLogger(__name__) class BaseEmbedding(BaseProviderModel): name: str @@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel): super().__init__(model_provider, client) self.name = name + @property + def base_model_name(self) -> str: + """ + get base model name + + :return: str + """ + return self.name + + @property + def price_config(self) -> dict: + def get_or_default(): + default_price_config = { + 'prompt': decimal.Decimal('0'), + 'completion': decimal.Decimal('0'), + 'unit': decimal.Decimal('0'), + 'currency': 'USD' + } + rules = self.model_provider.get_rules() + price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config + price_config = { + 'prompt': decimal.Decimal(price_config['prompt']), + 'completion': decimal.Decimal(price_config['completion']), + 'unit': decimal.Decimal(price_config['unit']), + 'currency': price_config['currency'] + } + return price_config + + self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default() + + logger.debug(f"model: {self.name} price_config: {self._price_config}") + return self._price_config + + def calc_tokens_price(self, tokens:int) -> decimal.Decimal: + """ + calc tokens total price. + + :param tokens: + :return: decimal.Decimal('0.0000001') + """ + unit_price = self._price_config['completion'] + unit = self._price_config['unit'] + total_price = tokens * unit_price * unit + total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") + return total_price + + def get_tokens_unit_price(self) -> decimal.Decimal: + """ + get token price. + + :return: decimal.Decimal('0.0001') + + """ + unit_price = self._price_config['completion'] + unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP) + logger.debug(f'unit_price:{unit_price}') + return unit_price + def get_num_tokens(self, text: str) -> int: """ get num tokens of text. @@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel): return len(_get_token_ids_default_method(text)) - def get_token_price(self, tokens: int): - return 0 - def get_currency(self): - return 'USD' + """ + get token currency. + + :return: get from price config, default 'USD' + """ + currency = self._price_config['currency'] + return currency @abstractmethod def handle_exceptions(self, ex: Exception) -> Exception: diff --git a/api/core/model_providers/models/embedding/minimax_embedding.py b/api/core/model_providers/models/embedding/minimax_embedding.py index d8cb22f347..185c66ab76 100644 --- a/api/core/model_providers/models/embedding/minimax_embedding.py +++ b/api/core/model_providers/models/embedding/minimax_embedding.py @@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding): super().__init__(model_provider, client, name) - def get_token_price(self, tokens: int): - return decimal.Decimal('0') - def get_currency(self): return 'RMB' diff --git a/api/core/model_providers/models/embedding/openai_embedding.py b/api/core/model_providers/models/embedding/openai_embedding.py index 3ab65c291d..54444b121c 100644 --- a/api/core/model_providers/models/embedding/openai_embedding.py +++ b/api/core/model_providers/models/embedding/openai_embedding.py @@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding): # calculate the number of tokens in the encoded text return len(tokenized_text) - def get_token_price(self, tokens: int): - tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), - rounding=decimal.ROUND_HALF_UP) - - total_price = tokens_per_1k * decimal.Decimal('0.0001') - return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) - - def get_currency(self): - return 'USD' - def handle_exceptions(self, ex: Exception) -> Exception: if isinstance(ex, openai.error.InvalidRequestError): logging.warning("Invalid request to OpenAI API.") diff --git a/api/core/model_providers/models/embedding/replicate_embedding.py b/api/core/model_providers/models/embedding/replicate_embedding.py index 3f7ef2851d..962593fcdb 100644 --- a/api/core/model_providers/models/embedding/replicate_embedding.py +++ b/api/core/model_providers/models/embedding/replicate_embedding.py @@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding): super().__init__(model_provider, client, name) - def get_token_price(self, tokens: int): - # replicate only pay for prediction seconds - return decimal.Decimal('0') - - def get_currency(self): - return 'USD' - def handle_exceptions(self, ex: Exception) -> Exception: if isinstance(ex, (ModelError, ReplicateError)): return LLMBadRequestError(f"Replicate: {str(ex)}") diff --git a/api/core/model_providers/models/llm/anthropic_model.py b/api/core/model_providers/models/llm/anthropic_model.py index 69dd76611f..dd6c17798d 100644 --- a/api/core/model_providers/models/llm/anthropic_model.py +++ b/api/core/model_providers/models/llm/anthropic_model.py @@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) - def get_token_price(self, tokens: int, message_type: MessageType): - model_unit_prices = { - 'claude-instant-1': { - 'prompt': decimal.Decimal('1.63'), - 'completion': decimal.Decimal('5.51'), - }, - 'claude-2': { - 'prompt': decimal.Decimal('11.02'), - 'completion': decimal.Decimal('32.68'), - }, - } - - if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: - unit_price = model_unit_prices[self.name]['prompt'] - else: - unit_price = model_unit_prices[self.name]['completion'] - - tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'), - rounding=decimal.ROUND_HALF_UP) - - total_price = tokens_per_1m * unit_price - return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP) - - def get_currency(self): - return 'USD' - def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) for k, v in provider_model_kwargs.items(): diff --git a/api/core/model_providers/models/llm/azure_openai_model.py b/api/core/model_providers/models/llm/azure_openai_model.py index 66acaaeaa1..a5a0d13d99 100644 --- a/api/core/model_providers/models/llm/azure_openai_model.py +++ b/api/core/model_providers/models/llm/azure_openai_model.py @@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM): self.model_mode = ModelMode.COMPLETION else: self.model_mode = ModelMode.CHAT - super().__init__(model_provider, name, model_kwargs, streaming, callbacks) def _init_client(self) -> Any: @@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM): """ prompts = self._get_prompt_from_messages(messages) return self._client.generate([prompts], stop, callbacks) + + @property + def base_model_name(self) -> str: + """ + get base model name (not deployment) + + :return: str + """ + return self.credentials.get("base_model_name") def get_num_tokens(self, messages: List[PromptMessage]) -> int: """ @@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM): else: return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) - def get_token_price(self, tokens: int, message_type: MessageType): - model_unit_prices = { - 'gpt-4': { - 'prompt': decimal.Decimal('0.03'), - 'completion': decimal.Decimal('0.06'), - }, - 'gpt-4-32k': { - 'prompt': decimal.Decimal('0.06'), - 'completion': decimal.Decimal('0.12') - }, - 'gpt-35-turbo': { - 'prompt': decimal.Decimal('0.0015'), - 'completion': decimal.Decimal('0.002') - }, - 'gpt-35-turbo-16k': { - 'prompt': decimal.Decimal('0.003'), - 'completion': decimal.Decimal('0.004') - }, - 'text-davinci-003': { - 'prompt': decimal.Decimal('0.02'), - 'completion': decimal.Decimal('0.02') - }, - } - - base_model_name = self.credentials.get("base_model_name") - if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: - unit_price = model_unit_prices[base_model_name]['prompt'] - else: - unit_price = model_unit_prices[base_model_name]['completion'] - - tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), - rounding=decimal.ROUND_HALF_UP) - - total_price = tokens_per_1k * unit_price - return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) - - def get_currency(self): - return 'USD' - def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) if self.name == 'text-davinci-003': diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py index a00ea87504..6b20098be3 100644 --- a/api/core/model_providers/models/llm/base.py +++ b/api/core/model_providers/models/llm/base.py @@ -1,5 +1,6 @@ from abc import abstractmethod from typing import List, Optional, Any, Union +import decimal from langchain.callbacks.manager import Callbacks from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration @@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules from core.model_providers.providers.base import BaseModelProvider from core.third_party.langchain.llms.fake import FakeLLM +import logging +logger = logging.getLogger(__name__) class BaseLLM(BaseProviderModel): @@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel): def _init_client(self) -> Any: raise NotImplementedError + @property + def base_model_name(self) -> str: + """ + get llm base model name + + :return: str + """ + return self.name + + @property + def price_config(self) -> dict: + def get_or_default(): + default_price_config = { + 'prompt': decimal.Decimal('0'), + 'completion': decimal.Decimal('0'), + 'unit': decimal.Decimal('0'), + 'currency': 'USD' + } + rules = self.model_provider.get_rules() + price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config + price_config = { + 'prompt': decimal.Decimal(price_config['prompt']), + 'completion': decimal.Decimal(price_config['completion']), + 'unit': decimal.Decimal(price_config['unit']), + 'currency': price_config['currency'] + } + return price_config + + self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default() + + logger.debug(f"model: {self.name} price_config: {self._price_config}") + return self._price_config + def run(self, messages: List[PromptMessage], stop: Optional[List[str]] = None, callbacks: Callbacks = None, @@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel): """ raise NotImplementedError - @abstractmethod - def get_token_price(self, tokens: int, message_type: MessageType): + def calc_tokens_price(self, tokens:int, message_type: MessageType): """ - get token price. + calc tokens total price. :param tokens: :param message_type: :return: """ - raise NotImplementedError + if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + unit_price = self.price_config['prompt'] + else: + unit_price = self.price_config['completion'] + unit = self.price_config['unit'] + + total_price = tokens * unit_price * unit + total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") + return total_price + + def get_tokens_unit_price(self, message_type: MessageType): + """ + get token price. + + :param message_type: + :return: decimal.Decimal('0.0001') + """ + if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + unit_price = self.price_config['prompt'] + else: + unit_price = self.price_config['completion'] + unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP) + logging.debug(f"unit_price={unit_price}") + return unit_price - @abstractmethod def get_currency(self): """ get token currency. - :return: + :return: get from price config, default 'USD' """ - raise NotImplementedError + currency = self.price_config['currency'] + return currency def get_model_kwargs(self): return self.model_kwargs diff --git a/api/core/model_providers/models/llm/chatglm_model.py b/api/core/model_providers/models/llm/chatglm_model.py index 42036dbfdd..f3ce9ceaf0 100644 --- a/api/core/model_providers/models/llm/chatglm_model.py +++ b/api/core/model_providers/models/llm/chatglm_model.py @@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens(prompts), 0) - def get_token_price(self, tokens: int, message_type: MessageType): - return decimal.Decimal('0') - def get_currency(self): return 'RMB' diff --git a/api/core/model_providers/models/llm/huggingface_hub_model.py b/api/core/model_providers/models/llm/huggingface_hub_model.py index f5deded517..16aec70c30 100644 --- a/api/core/model_providers/models/llm/huggingface_hub_model.py +++ b/api/core/model_providers/models/llm/huggingface_hub_model.py @@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return self._client.get_num_tokens(prompts) - def get_token_price(self, tokens: int, message_type: MessageType): - # not support calc price - return decimal.Decimal('0') - - def get_currency(self): - return 'USD' - def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) self.client.model_kwargs = provider_model_kwargs diff --git a/api/core/model_providers/models/llm/minimax_model.py b/api/core/model_providers/models/llm/minimax_model.py index b7e38462f0..e2252d7edc 100644 --- a/api/core/model_providers/models/llm/minimax_model.py +++ b/api/core/model_providers/models/llm/minimax_model.py @@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens(prompts), 0) - def get_token_price(self, tokens: int, message_type: MessageType): - return decimal.Decimal('0') - def get_currency(self): return 'RMB' diff --git a/api/core/model_providers/models/llm/openai_model.py b/api/core/model_providers/models/llm/openai_model.py index e7026c0241..91db37df6f 100644 --- a/api/core/model_providers/models/llm/openai_model.py +++ b/api/core/model_providers/models/llm/openai_model.py @@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM): self.model_mode = ModelMode.COMPLETION else: self.model_mode = ModelMode.CHAT - + + # TODO load price config from configs(db) super().__init__(model_provider, name, model_kwargs, streaming, callbacks) def _init_client(self) -> Any: @@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM): else: return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) - def get_token_price(self, tokens: int, message_type: MessageType): - model_unit_prices = { - 'gpt-4': { - 'prompt': decimal.Decimal('0.03'), - 'completion': decimal.Decimal('0.06'), - }, - 'gpt-4-32k': { - 'prompt': decimal.Decimal('0.06'), - 'completion': decimal.Decimal('0.12') - }, - 'gpt-3.5-turbo': { - 'prompt': decimal.Decimal('0.0015'), - 'completion': decimal.Decimal('0.002') - }, - 'gpt-3.5-turbo-16k': { - 'prompt': decimal.Decimal('0.003'), - 'completion': decimal.Decimal('0.004') - }, - 'text-davinci-003': { - 'prompt': decimal.Decimal('0.02'), - 'completion': decimal.Decimal('0.02') - }, - } - - if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: - unit_price = model_unit_prices[self.name]['prompt'] - else: - unit_price = model_unit_prices[self.name]['completion'] - - tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), - rounding=decimal.ROUND_HALF_UP) - - total_price = tokens_per_1k * unit_price - return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) - - def get_currency(self): - return 'USD' - def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) if self.name in COMPLETION_MODELS: diff --git a/api/core/model_providers/models/llm/replicate_model.py b/api/core/model_providers/models/llm/replicate_model.py index 7dd7eb8531..e740440ac2 100644 --- a/api/core/model_providers/models/llm/replicate_model.py +++ b/api/core/model_providers/models/llm/replicate_model.py @@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM): return self._client.get_num_tokens(prompts) - def get_token_price(self, tokens: int, message_type: MessageType): - # replicate only pay for prediction seconds - return decimal.Decimal('0') - - def get_currency(self): - return 'USD' - def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) self.client.input = provider_model_kwargs diff --git a/api/core/model_providers/models/llm/spark_model.py b/api/core/model_providers/models/llm/spark_model.py index a16318637c..a7b63ae058 100644 --- a/api/core/model_providers/models/llm/spark_model.py +++ b/api/core/model_providers/models/llm/spark_model.py @@ -50,9 +50,6 @@ class SparkModel(BaseLLM): contents = [message.content for message in messages] return max(self._client.get_num_tokens("".join(contents)), 0) - def get_token_price(self, tokens: int, message_type: MessageType): - return decimal.Decimal('0') - def get_currency(self): return 'RMB' diff --git a/api/core/model_providers/models/llm/tongyi_model.py b/api/core/model_providers/models/llm/tongyi_model.py index f950275f77..7138338a0c 100644 --- a/api/core/model_providers/models/llm/tongyi_model.py +++ b/api/core/model_providers/models/llm/tongyi_model.py @@ -53,9 +53,6 @@ class TongyiModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens(prompts), 0) - def get_token_price(self, tokens: int, message_type: MessageType): - return decimal.Decimal('0') - def get_currency(self): return 'RMB' diff --git a/api/core/model_providers/models/llm/wenxin_model.py b/api/core/model_providers/models/llm/wenxin_model.py index 2c950679ab..0f42ad27b5 100644 --- a/api/core/model_providers/models/llm/wenxin_model.py +++ b/api/core/model_providers/models/llm/wenxin_model.py @@ -16,6 +16,7 @@ class WenxinModel(BaseLLM): def _init_client(self) -> Any: provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + # TODO load price_config from configs(db) return Wenxin( streaming=self.streaming, callbacks=self.callbacks, @@ -48,36 +49,6 @@ class WenxinModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens(prompts), 0) - def get_token_price(self, tokens: int, message_type: MessageType): - model_unit_prices = { - 'ernie-bot': { - 'prompt': decimal.Decimal('0.012'), - 'completion': decimal.Decimal('0.012'), - }, - 'ernie-bot-turbo': { - 'prompt': decimal.Decimal('0.008'), - 'completion': decimal.Decimal('0.008') - }, - 'bloomz-7b': { - 'prompt': decimal.Decimal('0.006'), - 'completion': decimal.Decimal('0.006') - } - } - - if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: - unit_price = model_unit_prices[self.name]['prompt'] - else: - unit_price = model_unit_prices[self.name]['completion'] - - tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), - rounding=decimal.ROUND_HALF_UP) - - total_price = tokens_per_1k * unit_price - return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) - - def get_currency(self): - return 'RMB' - def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) for k, v in provider_model_kwargs.items(): diff --git a/api/core/model_providers/rules/anthropic.json b/api/core/model_providers/rules/anthropic.json index 8e0bee4425..c0aac8617c 100644 --- a/api/core/model_providers/rules/anthropic.json +++ b/api/core/model_providers/rules/anthropic.json @@ -11,5 +11,19 @@ "quota_unit": "tokens", "quota_limit": 600000 }, - "model_flexibility": "fixed" + "model_flexibility": "fixed", + "price_config": { + "claude-instant-1": { + "prompt": "1.63", + "completion": "5.51", + "unit": "0.000001", + "currency": "USD" + }, + "claude-2": { + "prompt": "11.02", + "completion": "32.68", + "unit": "0.000001", + "currency": "USD" + } + } } \ No newline at end of file diff --git a/api/core/model_providers/rules/azure_openai.json b/api/core/model_providers/rules/azure_openai.json index 5badb07178..dfb354d5a4 100644 --- a/api/core/model_providers/rules/azure_openai.json +++ b/api/core/model_providers/rules/azure_openai.json @@ -3,5 +3,48 @@ "custom" ], "system_config": null, - "model_flexibility": "configurable" + "model_flexibility": "configurable", + "price_config":{ + "gpt-4": { + "prompt": "0.03", + "completion": "0.06", + "unit": "0.001", + "currency": "USD" + }, + "gpt-4-32k": { + "prompt": "0.06", + "completion": "0.12", + "unit": "0.001", + "currency": "USD" + }, + "gpt-35-turbo": { + "prompt": "0.0015", + "completion": "0.002", + "unit": "0.001", + "currency": "USD" + }, + "gpt-35-turbo-16k": { + "prompt": "0.003", + "completion": "0.004", + "unit": "0.001", + "currency": "USD" + }, + "text-davinci-002": { + "prompt": "0.02", + "completion": "0.02", + "unit": "0.001", + "currency": "USD" + }, + "text-davinci-003": { + "prompt": "0.02", + "completion": "0.02", + "unit": "0.001", + "currency": "USD" + }, + "text-embedding-ada-002":{ + "completion": "0.0001", + "unit": "0.001", + "currency": "USD" + } + } } \ No newline at end of file diff --git a/api/core/model_providers/rules/openai.json b/api/core/model_providers/rules/openai.json index e615de6063..aa2c9f3363 100644 --- a/api/core/model_providers/rules/openai.json +++ b/api/core/model_providers/rules/openai.json @@ -10,5 +10,42 @@ "quota_unit": "times", "quota_limit": 200 }, - "model_flexibility": "fixed" + "model_flexibility": "fixed", + "price_config": { + "gpt-4": { + "prompt": "0.03", + "completion": "0.06", + "unit": "0.001", + "currency": "USD" + }, + "gpt-4-32k": { + "prompt": "0.06", + "completion": "0.12", + "unit": "0.001", + "currency": "USD" + }, + "gpt-3.5-turbo": { + "prompt": "0.0015", + "completion": "0.002", + "unit": "0.001", + "currency": "USD" + }, + "gpt-3.5-turbo-16k": { + "prompt": "0.003", + "completion": "0.004", + "unit": "0.001", + "currency": "USD" + }, + "text-davinci-003": { + "prompt": "0.02", + "completion": "0.02", + "unit": "0.001", + "currency": "USD" + }, + "text-embedding-ada-002":{ + "completion": "0.0001", + "unit": "0.001", + "currency": "USD" + } + } } \ No newline at end of file diff --git a/api/core/model_providers/rules/wenxin.json b/api/core/model_providers/rules/wenxin.json index 0af3e61ec7..e5c136d326 100644 --- a/api/core/model_providers/rules/wenxin.json +++ b/api/core/model_providers/rules/wenxin.json @@ -3,5 +3,25 @@ "custom" ], "system_config": null, - "model_flexibility": "fixed" + "model_flexibility": "fixed", + "price_config": { + "ernie-bot": { + "prompt": "0.012", + "completion": "0.012", + "unit": "0.001", + "currency": "RMB" + }, + "ernie-bot-turbo": { + "prompt": "0.008", + "completion": "0.008", + "unit": "0.001", + "currency": "RMB" + }, + "bloomz-7b": { + "prompt": "0.006", + "completion": "0.006", + "unit": "0.001", + "currency": "RMB" + } + } } \ No newline at end of file