Fix/price calc (#862)

Krasus.Chen 2023-08-19 16:41:35 +08:00 committed by GitHub
parent 1c552ff23a
commit fd0fc8f4fe
22 changed files with 288 additions and 230 deletions

View File

@@ -140,10 +140,13 @@ class ConversationMessageTask:
     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
-        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
-        total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
+        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
+        total_price = message_total_price + answer_total_price

         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens
@@ -206,18 +209,15 @@ class ConversationMessageTask:
     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
-        loop_total_price = self.calc_total_price(
-            loop_message_tokens,
-            agent_message_unit_price,
-            loop_answer_tokens,
-            agent_answer_unit_price
-        )
+        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_total_price = loop_message_total_price + loop_answer_total_price

         message_agent_thought.observation = agent_loop.tool_output
         message_agent_thought.tool_process_data = ''  # currently not support
@@ -243,15 +243,6 @@ class ConversationMessageTask:
             db.session.add(dataset_query)

-    def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
-        message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                  rounding=decimal.ROUND_HALF_UP)
-        answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                rounding=decimal.ROUND_HALF_UP)
-        total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
     def end(self):
         self._pub_handler.pub_end()
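
For reference, when a provider's unit is "0.001" (per-1K pricing) the removed calc_total_price helper and the new sum of calc_tokens_price calls produce the same figure. A minimal standalone sketch; old_total and new_total are hypothetical stand-ins mirroring the removed and added code, not part of this commit:

import decimal

def old_total(message_tokens, message_unit_price, answer_tokens, answer_unit_price):
    # mirrors the removed ConversationMessageTask.calc_total_price
    m = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                          rounding=decimal.ROUND_HALF_UP)
    a = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                         rounding=decimal.ROUND_HALF_UP)
    total = m * message_unit_price + a * answer_unit_price
    return total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

def new_total(message_tokens, answer_tokens, prompt_price, completion_price, unit):
    # mirrors message_total_price + answer_total_price via calc_tokens_price
    def calc(tokens, price):
        return (tokens * price * unit).quantize(decimal.Decimal('0.0000001'),
                                                rounding=decimal.ROUND_HALF_UP)
    return calc(message_tokens, prompt_price) + calc(answer_tokens, completion_price)

prompt_price = decimal.Decimal('0.0015')     # e.g. gpt-3.5-turbo prompt price per 1K tokens
completion_price = decimal.Decimal('0.002')  # completion price per 1K tokens
unit = decimal.Decimal('0.001')              # per-1K pricing unit
print(old_total(1000, prompt_price, 2000, completion_price))        # 0.0055000
print(new_total(1000, 2000, prompt_price, completion_price, unit))  # 0.0055000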

View File

@@ -278,7 +278,7 @@ class IndexingRunner:
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
@@ -286,7 +286,7 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
             "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }
@@ -371,7 +371,7 @@ class IndexingRunner:
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts

View File

@@ -32,6 +32,15 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         super().__init__(model_provider, client, name)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.
@@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to Azure OpenAI API.")

View File

@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from typing import Any
+import decimal

 import tiktoken
 from langchain.schema.language_model import _get_token_ids_default_method
@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import BaseModelProvider
+import logging
+
+logger = logging.getLogger(__name__)


 class BaseEmbedding(BaseProviderModel):
     name: str
@@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel):
         super().__init__(model_provider, client)
         self.name = name

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        # parse once and cache on the instance
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
+    def calc_tokens_price(self, tokens: int) -> decimal.Decimal:
+        """
+        calc tokens total price.
+
+        :param tokens:
+        :return: decimal.Decimal('0.0000001')
+        """
+        unit_price = self.price_config['completion']
+        unit = self.price_config['unit']
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price={total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self) -> decimal.Decimal:
+        """
+        get token price.
+
+        :return: decimal.Decimal('0.0001')
+        """
+        unit_price = self.price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f'unit_price:{unit_price}')
+        return unit_price
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.
@@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel):
         return len(_get_token_ids_default_method(text))

-    def get_token_price(self, tokens: int):
-        return 0
-
     def get_currency(self):
-        return 'USD'
+        """
+        get token currency.
+
+        :return: get from price config, default 'USD'
+        """
+        currency = self.price_config['currency']
+        return currency

     @abstractmethod
     def handle_exceptions(self, ex: Exception) -> Exception:
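
The get_or_default fallback above means providers whose rules ship no price_config (including the zero-price stubs removed from the provider-specific classes below) now resolve to an all-zero config instead of needing their own overrides. A minimal sketch of that fallback path, under the assumption that the provider rules come back without a price_config key:

import decimal

rules = {"model_flexibility": "fixed"}  # hypothetical provider rules without price_config

default_price_config = {
    'prompt': decimal.Decimal('0'),
    'completion': decimal.Decimal('0'),
    'unit': decimal.Decimal('0'),
    'currency': 'USD'
}
price_config = rules['price_config'] if 'price_config' in rules else default_price_config

# with a zero unit price and zero unit, any token count prices out to 0
print(price_config['completion'] * 10000 * price_config['unit'])  # 0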

View File

@@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding):
         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

View File

@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to OpenAI API.")

View File

@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, (ModelError, ReplicateError)):
             return LLMBadRequestError(f"Replicate: {str(ex)}")

View File

@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'claude-instant-1': {
-                'prompt': decimal.Decimal('1.63'),
-                'completion': decimal.Decimal('5.51'),
-            },
-            'claude-2': {
-                'prompt': decimal.Decimal('11.02'),
-                'completion': decimal.Decimal('32.68'),
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
-                                                                     rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1m * unit_price
-        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():

View File

@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
             self.model_mode = ModelMode.COMPLETION
         else:
             self.model_mode = ModelMode.CHAT
-
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

     def _init_client(self) -> Any:
@@ -84,6 +83,15 @@ class AzureOpenAIModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.
@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-35-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-35-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        base_model_name = self.credentials.get("base_model_name")
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[base_model_name]['prompt']
-        else:
-            unit_price = model_unit_prices[base_model_name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name == 'text-davinci-003':

View File

@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from typing import List, Optional, Any, Union
+import decimal

 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.third_party.langchain.llms.fake import FakeLLM
+import logging
+
+logger = logging.getLogger(__name__)


 class BaseLLM(BaseProviderModel):
@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
     def _init_client(self) -> Any:
         raise NotImplementedError

+    @property
+    def base_model_name(self) -> str:
+        """
+        get llm base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
     def run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
@@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
         """
         raise NotImplementedError

-    @abstractmethod
-    def get_token_price(self, tokens: int, message_type: MessageType):
+    def calc_tokens_price(self, tokens: int, message_type: MessageType):
         """
-        get token price.
+        calc tokens total price.

         :param tokens:
         :param message_type:
         :return:
         """
-        raise NotImplementedError
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+
+        unit = self.price_config['unit']
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price={total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self, message_type: MessageType):
+        """
+        get token price.
+
+        :param message_type:
+        :return: decimal.Decimal('0.0001')
+        """
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f"unit_price={unit_price}")
+        return unit_price

-    @abstractmethod
     def get_currency(self):
         """
         get token currency.

-        :return:
+        :return: get from price config, default 'USD'
         """
-        raise NotImplementedError
+        currency = self.price_config['currency']
+        return currency

     def get_model_kwargs(self):
         return self.model_kwargs
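
To make the new arithmetic concrete, a minimal sketch using the gpt-3.5-turbo figures from the OpenAI rules further down (prompt "0.0015", unit "0.001", i.e. USD per 1K tokens); the values are taken from this commit's config, the snippet itself is illustrative only:

import decimal

tokens = 1200                            # prompt tokens for one request
unit_price = decimal.Decimal('0.0015')   # price_config['prompt'] for gpt-3.5-turbo
unit = decimal.Decimal('0.001')          # scales the per-1K price down to per-token

total_price = (tokens * unit_price * unit).quantize(decimal.Decimal('0.0000001'),
                                                    rounding=decimal.ROUND_HALF_UP)
print(total_price)  # 0.0018000 -> 1200 prompt tokens cost USD 0.0018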

View File

@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

View File

@@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # not support calc price
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs

View File

@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

View File

@@ -47,6 +47,7 @@ class OpenAIModel(BaseLLM):
         else:
             self.model_mode = ModelMode.CHAT

+        # TODO load price config from configs(db)
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

     def _init_client(self) -> Any:
@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-3.5-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-3.5-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name in COMPLETION_MODELS:

View File

@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):

         return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.input = provider_model_kwargs

View File

@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
         contents = [message.content for message in messages]
         return max(self._client.get_num_tokens("".join(contents)), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

View File

@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

View File

@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

+        # TODO load price_config from configs(db)
         return Wenxin(
             streaming=self.streaming,
             callbacks=self.callbacks,
@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'ernie-bot': {
-                'prompt': decimal.Decimal('0.012'),
-                'completion': decimal.Decimal('0.012'),
-            },
-            'ernie-bot-turbo': {
-                'prompt': decimal.Decimal('0.008'),
-                'completion': decimal.Decimal('0.008')
-            },
-            'bloomz-7b': {
-                'prompt': decimal.Decimal('0.006'),
-                'completion': decimal.Decimal('0.006')
-            }
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'RMB'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():

View File

@@ -11,5 +11,19 @@
     "quota_unit": "tokens",
     "quota_limit": 600000
   },
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "claude-instant-1": {
+      "prompt": "1.63",
+      "completion": "5.51",
+      "unit": "0.000001",
+      "currency": "USD"
+    },
+    "claude-2": {
+      "prompt": "11.02",
+      "completion": "32.68",
+      "unit": "0.000001",
+      "currency": "USD"
+    }
+  }
 }
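
Unlike the OpenAI-style configs below, Anthropic uses unit "0.000001", so the prompt/completion figures read as USD per million tokens, matching the per-1M arithmetic in the removed AnthropicModel.get_token_price. A quick check (a sketch using the claude-instant-1 figures above):

import decimal

prompt_price = decimal.Decimal('1.63')  # claude-instant-1 prompt price
unit = decimal.Decimal('0.000001')      # per-token scale for per-1M pricing
print(1_000_000 * prompt_price * unit)  # 1.63000000 -> USD 1.63 per million prompt tokens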

View File

@@ -3,5 +3,48 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "configurable"
+  "model_flexibility": "configurable",
+  "price_config": {
+    "gpt-4": {
+      "prompt": "0.03",
+      "completion": "0.06",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-4-32k": {
+      "prompt": "0.06",
+      "completion": "0.12",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-35-turbo": {
+      "prompt": "0.0015",
+      "completion": "0.002",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-35-turbo-16k": {
+      "prompt": "0.003",
+      "completion": "0.004",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-002": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-003": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-embedding-ada-002": {
+      "completion": "0.0001",
+      "unit": "0.001",
+      "currency": "USD"
+    }
+  }
 }

View File

@@ -10,5 +10,42 @@
     "quota_unit": "times",
     "quota_limit": 200
   },
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "gpt-4": {
+      "prompt": "0.03",
+      "completion": "0.06",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-4-32k": {
+      "prompt": "0.06",
+      "completion": "0.12",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-3.5-turbo": {
+      "prompt": "0.0015",
+      "completion": "0.002",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-3.5-turbo-16k": {
+      "prompt": "0.003",
+      "completion": "0.004",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-003": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-embedding-ada-002": {
+      "completion": "0.0001",
+      "unit": "0.001",
+      "currency": "USD"
+    }
+  }
 }

View File

@@ -3,5 +3,25 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "ernie-bot": {
+      "prompt": "0.012",
+      "completion": "0.012",
+      "unit": "0.001",
+      "currency": "RMB"
+    },
+    "ernie-bot-turbo": {
+      "prompt": "0.008",
+      "completion": "0.008",
+      "unit": "0.001",
+      "currency": "RMB"
+    },
+    "bloomz-7b": {
+      "prompt": "0.006",
+      "completion": "0.006",
+      "unit": "0.001",
+      "currency": "RMB"
+    }
+  }
 }