mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 01:08:57 +08:00
Fix/price calc (#862)
This commit is contained in:
parent
1c552ff23a
commit
fd0fc8f4fe
@ -140,10 +140,13 @@ class ConversationMessageTask:
|
||||
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
|
||||
message_tokens = llm_message.prompt_tokens
|
||||
answer_tokens = llm_message.completion_tokens
|
||||
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
|
||||
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
|
||||
|
||||
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
|
||||
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
|
||||
total_price = message_total_price + answer_total_price
|
||||
|
||||
self.message.message = llm_message.prompt
|
||||
self.message.message_tokens = message_tokens
|
||||
@ -206,18 +209,15 @@ class ConversationMessageTask:
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
|
||||
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
||||
loop_total_price = self.calc_total_price(
|
||||
loop_message_tokens,
|
||||
agent_message_unit_price,
|
||||
loop_answer_tokens,
|
||||
agent_answer_unit_price
|
||||
)
|
||||
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_total_price = loop_message_total_price + loop_answer_total_price
|
||||
|
||||
message_agent_thought.observation = agent_loop.tool_output
|
||||
message_agent_thought.tool_process_data = '' # currently not support
|
||||
@ -243,15 +243,6 @@ class ConversationMessageTask:
|
||||
|
||||
db.session.add(dataset_query)
|
||||
|
||||
def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
|
||||
message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def end(self):
|
||||
self._pub_handler.pub_end()
|
||||
|
||||
|
@ -278,7 +278,7 @@ class IndexingRunner:
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@ -286,7 +286,7 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"preview": preview_texts
|
||||
}
|
||||
@ -371,7 +371,7 @@ class IndexingRunner:
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
|
@ -31,6 +31,15 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name (not deployment)
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.credentials.get("base_model_name")
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to Azure OpenAI API.")
|
||||
|
@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
import decimal
|
||||
|
||||
import tiktoken
|
||||
from langchain.schema.language_model import _get_token_ids_default_method
|
||||
@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseEmbedding(BaseProviderModel):
|
||||
name: str
|
||||
@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel):
|
||||
super().__init__(model_provider, client)
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def price_config(self) -> dict:
|
||||
def get_or_default():
|
||||
default_price_config = {
|
||||
'prompt': decimal.Decimal('0'),
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
rules = self.model_provider.get_rules()
|
||||
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = {
|
||||
'prompt': decimal.Decimal(price_config['prompt']),
|
||||
'completion': decimal.Decimal(price_config['completion']),
|
||||
'unit': decimal.Decimal(price_config['unit']),
|
||||
'currency': price_config['currency']
|
||||
}
|
||||
return price_config
|
||||
|
||||
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
||||
|
||||
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
||||
return self._price_config
|
||||
|
||||
def calc_tokens_price(self, tokens:int) -> decimal.Decimal:
|
||||
"""
|
||||
calc tokens total price.
|
||||
|
||||
:param tokens:
|
||||
:return: decimal.Decimal('0.0000001')
|
||||
"""
|
||||
unit_price = self._price_config['completion']
|
||||
unit = self._price_config['unit']
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
|
||||
:return: decimal.Decimal('0.0001')
|
||||
|
||||
"""
|
||||
unit_price = self._price_config['completion']
|
||||
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logger.debug(f'unit_price:{unit_price}')
|
||||
return unit_price
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""
|
||||
get num tokens of text.
|
||||
@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel):
|
||||
|
||||
return len(_get_token_ids_default_method(text))
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return 0
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
:return: get from price config, default 'USD'
|
||||
"""
|
||||
currency = self._price_config['currency']
|
||||
return currency
|
||||
|
||||
@abstractmethod
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
|
@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
|
||||
# calculate the number of tokens in the encoded text
|
||||
return len(tokenized_text)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * decimal.Decimal('0.0001')
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to OpenAI API.")
|
||||
|
@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def get_token_price(self, tokens: int):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Replicate: {str(ex)}")
|
||||
|
@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'claude-instant-1': {
|
||||
'prompt': decimal.Decimal('1.63'),
|
||||
'completion': decimal.Decimal('5.51'),
|
||||
},
|
||||
'claude-2': {
|
||||
'prompt': decimal.Decimal('11.02'),
|
||||
'completion': decimal.Decimal('32.68'),
|
||||
},
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1m * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
|
@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
|
||||
self.model_mode = ModelMode.COMPLETION
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM):
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get base model name (not deployment)
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.credentials.get("base_model_name")
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
|
||||
else:
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'gpt-4': {
|
||||
'prompt': decimal.Decimal('0.03'),
|
||||
'completion': decimal.Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': decimal.Decimal('0.06'),
|
||||
'completion': decimal.Decimal('0.12')
|
||||
},
|
||||
'gpt-35-turbo': {
|
||||
'prompt': decimal.Decimal('0.0015'),
|
||||
'completion': decimal.Decimal('0.002')
|
||||
},
|
||||
'gpt-35-turbo-16k': {
|
||||
'prompt': decimal.Decimal('0.003'),
|
||||
'completion': decimal.Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': decimal.Decimal('0.02'),
|
||||
'completion': decimal.Decimal('0.02')
|
||||
},
|
||||
}
|
||||
|
||||
base_model_name = self.credentials.get("base_model_name")
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[base_model_name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[base_model_name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
if self.name == 'text-davinci-003':
|
||||
|
@ -1,5 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Any, Union
|
||||
import decimal
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
||||
@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLLM(BaseProviderModel):
|
||||
@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
|
||||
def _init_client(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
"""
|
||||
get llm base model name
|
||||
|
||||
:return: str
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def price_config(self) -> dict:
|
||||
def get_or_default():
|
||||
default_price_config = {
|
||||
'prompt': decimal.Decimal('0'),
|
||||
'completion': decimal.Decimal('0'),
|
||||
'unit': decimal.Decimal('0'),
|
||||
'currency': 'USD'
|
||||
}
|
||||
rules = self.model_provider.get_rules()
|
||||
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
|
||||
price_config = {
|
||||
'prompt': decimal.Decimal(price_config['prompt']),
|
||||
'completion': decimal.Decimal(price_config['completion']),
|
||||
'unit': decimal.Decimal(price_config['unit']),
|
||||
'currency': price_config['currency']
|
||||
}
|
||||
return price_config
|
||||
|
||||
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
|
||||
|
||||
logger.debug(f"model: {self.name} price_config: {self._price_config}")
|
||||
return self._price_config
|
||||
|
||||
def run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
def calc_tokens_price(self, tokens:int, message_type: MessageType):
|
||||
"""
|
||||
get token price.
|
||||
calc tokens total price.
|
||||
|
||||
:param tokens:
|
||||
:param message_type:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
unit = self.price_config['unit']
|
||||
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self, message_type: MessageType):
|
||||
"""
|
||||
get token price.
|
||||
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.0001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"unit_price={unit_price}")
|
||||
return unit_price
|
||||
|
||||
@abstractmethod
|
||||
def get_currency(self):
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
:return:
|
||||
:return: get from price config, default 'USD'
|
||||
"""
|
||||
raise NotImplementedError
|
||||
currency = self.price_config['currency']
|
||||
return currency
|
||||
|
||||
def get_model_kwargs(self):
|
||||
return self.model_kwargs
|
||||
|
@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
# not support calc price
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.model_kwargs = provider_model_kwargs
|
||||
|
@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM):
|
||||
self.model_mode = ModelMode.COMPLETION
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
|
||||
# TODO load price config from configs(db)
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
|
||||
else:
|
||||
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'gpt-4': {
|
||||
'prompt': decimal.Decimal('0.03'),
|
||||
'completion': decimal.Decimal('0.06'),
|
||||
},
|
||||
'gpt-4-32k': {
|
||||
'prompt': decimal.Decimal('0.06'),
|
||||
'completion': decimal.Decimal('0.12')
|
||||
},
|
||||
'gpt-3.5-turbo': {
|
||||
'prompt': decimal.Decimal('0.0015'),
|
||||
'completion': decimal.Decimal('0.002')
|
||||
},
|
||||
'gpt-3.5-turbo-16k': {
|
||||
'prompt': decimal.Decimal('0.003'),
|
||||
'completion': decimal.Decimal('0.004')
|
||||
},
|
||||
'text-davinci-003': {
|
||||
'prompt': decimal.Decimal('0.02'),
|
||||
'completion': decimal.Decimal('0.02')
|
||||
},
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
if self.name in COMPLETION_MODELS:
|
||||
|
@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):
|
||||
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
# replicate only pay for prediction seconds
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'USD'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.input = provider_model_kwargs
|
||||
|
@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
|
||||
contents = [message.content for message in messages]
|
||||
return max(self._client.get_num_tokens("".join(contents)), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
return decimal.Decimal('0')
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
|
@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
# TODO load price_config from configs(db)
|
||||
return Wenxin(
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def get_token_price(self, tokens: int, message_type: MessageType):
|
||||
model_unit_prices = {
|
||||
'ernie-bot': {
|
||||
'prompt': decimal.Decimal('0.012'),
|
||||
'completion': decimal.Decimal('0.012'),
|
||||
},
|
||||
'ernie-bot-turbo': {
|
||||
'prompt': decimal.Decimal('0.008'),
|
||||
'completion': decimal.Decimal('0.008')
|
||||
},
|
||||
'bloomz-7b': {
|
||||
'prompt': decimal.Decimal('0.006'),
|
||||
'completion': decimal.Decimal('0.006')
|
||||
}
|
||||
}
|
||||
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
unit_price = model_unit_prices[self.name]['prompt']
|
||||
else:
|
||||
unit_price = model_unit_prices[self.name]['completion']
|
||||
|
||||
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
|
||||
rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
total_price = tokens_per_1k * unit_price
|
||||
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
|
@ -11,5 +11,19 @@
|
||||
"quota_unit": "tokens",
|
||||
"quota_limit": 600000
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"claude-instant-1": {
|
||||
"prompt": "1.63",
|
||||
"completion": "5.51",
|
||||
"unit": "0.000001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"claude-2": {
|
||||
"prompt": "11.02",
|
||||
"completion": "32.68",
|
||||
"unit": "0.000001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
@ -3,5 +3,48 @@
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "configurable"
|
||||
"model_flexibility": "configurable",
|
||||
"price_config":{
|
||||
"gpt-4": {
|
||||
"prompt": "0.03",
|
||||
"completion": "0.06",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
"prompt": "0.06",
|
||||
"completion": "0.12",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-35-turbo": {
|
||||
"prompt": "0.0015",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-35-turbo-16k": {
|
||||
"prompt": "0.003",
|
||||
"completion": "0.004",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-002": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-003": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-embedding-ada-002":{
|
||||
"completion": "0.0001",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
@ -10,5 +10,42 @@
|
||||
"quota_unit": "times",
|
||||
"quota_limit": 200
|
||||
},
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"gpt-4": {
|
||||
"prompt": "0.03",
|
||||
"completion": "0.06",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-32k": {
|
||||
"prompt": "0.06",
|
||||
"completion": "0.12",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
"prompt": "0.0015",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
"prompt": "0.003",
|
||||
"completion": "0.004",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-davinci-003": {
|
||||
"prompt": "0.02",
|
||||
"completion": "0.02",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"text-embedding-ada-002":{
|
||||
"completion": "0.0001",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
}
|
||||
}
|
||||
}
|
@ -3,5 +3,25 @@
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "fixed"
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"ernie-bot": {
|
||||
"prompt": "0.012",
|
||||
"completion": "0.012",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"ernie-bot-turbo": {
|
||||
"prompt": "0.008",
|
||||
"completion": "0.008",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
},
|
||||
"bloomz-7b": {
|
||||
"prompt": "0.006",
|
||||
"completion": "0.006",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user