Fix/price calc (#862)
Commit: fd0fc8f4fe
Parent: 1c552ff23a
@@ -140,10 +140,13 @@ class ConversationMessageTask:
     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
-        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
-
-        total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
+        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+
+        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
+        total_price = message_total_price + answer_total_price

         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens

@@ -206,18 +209,15 @@ class ConversationMessageTask:

     def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens

-        loop_total_price = self.calc_total_price(
-            loop_message_tokens,
-            agent_message_unit_price,
-            loop_answer_tokens,
-            agent_answer_unit_price
-        )
+        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_total_price = loop_message_total_price + loop_answer_total_price

         message_agent_thought.observation = agent_loop.tool_output
         message_agent_thought.tool_process_data = ''  # currently not support

@@ -243,15 +243,6 @@ class ConversationMessageTask:

         db.session.add(dataset_query)

-    def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
-        message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                  rounding=decimal.ROUND_HALF_UP)
-        answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                rounding=decimal.ROUND_HALF_UP)
-
-        total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
     def end(self):
         self._pub_handler.pub_end()

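Note: the per-message totals are now delegated to the model instance instead of being computed locally in calc_total_price from per-1K unit prices. A minimal sketch of the arithmetic the new calc_tokens_price performs (the prices below are illustrative, not taken from any provider's rules):

import decimal

# Illustrative unit prices (not from any provider's rules JSON).
prompt_unit_price = decimal.Decimal('0.0015')      # price per 1K prompt tokens
completion_unit_price = decimal.Decimal('0.002')   # price per 1K completion tokens
unit = decimal.Decimal('0.001')                    # converts a raw token count to "per 1K"

message_tokens, answer_tokens = 200, 150
q = decimal.Decimal('0.0000001')
message_total_price = (message_tokens * prompt_unit_price * unit).quantize(q, rounding=decimal.ROUND_HALF_UP)
answer_total_price = (answer_tokens * completion_unit_price * unit).quantize(q, rounding=decimal.ROUND_HALF_UP)
total_price = message_total_price + answer_total_price
print(total_price)  # 0.0006000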
@@ -278,7 +278,7 @@ class IndexingRunner:
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts

@@ -286,7 +286,7 @@ class IndexingRunner:
         return {
             "total_segments": total_segments,
             "tokens": tokens,
-            "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
+            "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
             "currency": embedding_model.get_currency(),
             "preview": preview_texts
         }

@@ -371,7 +371,7 @@ class IndexingRunner:
                 "total_segments": total_segments * 20,
                 "tokens": total_segments * 2000,
                 "total_price": '{:f}'.format(
-                    text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
+                    text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
                 "currency": embedding_model.get_currency(),
                 "qa_preview": document_qa_list,
                 "preview": preview_texts
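Aside: '{:f}'.format(...) is kept here presumably because str() on a small quantized decimal.Decimal can fall back to scientific notation, while the 'f' format always prints plain fixed-point digits:

import decimal

price = decimal.Decimal('1E-7')   # a very small quantized price
print(str(price))                 # 1E-7  (scientific notation)
print('{:f}'.format(price))       # 0.0000001  (plain fixed-point)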
@@ -32,6 +32,15 @@ class AzureOpenAIEmbedding(BaseEmbedding):

         super().__init__(model_provider, client, name)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.

@@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to Azure OpenAI API.")
@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from typing import Any
+import decimal

 import tiktoken
 from langchain.schema.language_model import _get_token_ids_default_method

@@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.base import BaseModelProvider
+import logging
+logger = logging.getLogger(__name__)
-

 class BaseEmbedding(BaseProviderModel):
     name: str

@@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel):
         super().__init__(model_provider, client)
         self.name = name

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
+    def calc_tokens_price(self, tokens:int) -> decimal.Decimal:
+        """
+        calc tokens total price.
+
+        :param tokens:
+        :return: decimal.Decimal('0.0000001')
+        """
+        unit_price = self._price_config['completion']
+        unit = self._price_config['unit']
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self) -> decimal.Decimal:
+        """
+        get token price.
+
+        :return: decimal.Decimal('0.0001')
+        """
+        unit_price = self._price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logger.debug(f'unit_price:{unit_price}')
+        return unit_price
+
     def get_num_tokens(self, text: str) -> int:
         """
         get num tokens of text.

@@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel):

         return len(_get_token_ids_default_method(text))

-    def get_token_price(self, tokens: int):
-        return 0
-
     def get_currency(self):
-        return 'USD'
+        """
+        get token currency.
+
+        :return: get from price config, default 'USD'
+        """
+        currency = self._price_config['currency']
+        return currency

     @abstractmethod
     def handle_exceptions(self, ex: Exception) -> Exception:
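The new price_config property reads per-model prices from the provider rules (the JSON hunks at the end of this commit), converts them to decimal.Decimal, caches the result on the instance, and falls back to all-zero prices when a provider defines no price_config. A standalone sketch of that lookup, using a hypothetical rules dict in place of model_provider.get_rules() (the sketch supplies a "prompt" key, which the embedding entries added in this commit omit):

import decimal

def resolve_price_config(rules: dict, base_model_name: str) -> dict:
    # Hypothetical helper mirroring BaseEmbedding.price_config's fallback logic.
    default_price_config = {
        'prompt': decimal.Decimal('0'),
        'completion': decimal.Decimal('0'),
        'unit': decimal.Decimal('0'),
        'currency': 'USD'
    }
    raw = rules['price_config'][base_model_name] if 'price_config' in rules else default_price_config
    return {
        'prompt': decimal.Decimal(raw['prompt']),
        'completion': decimal.Decimal(raw['completion']),
        'unit': decimal.Decimal(raw['unit']),
        'currency': raw['currency']
    }

# Provider with a configured price (values match the OpenAI embedding entry added below).
rules = {"price_config": {"text-embedding-ada-002": {
    "prompt": "0", "completion": "0.0001", "unit": "0.001", "currency": "USD"}}}
print(resolve_price_config(rules, "text-embedding-ada-002")['completion'])  # 0.0001

# Provider without price_config -> zero prices.
print(resolve_price_config({}, "any-model")['completion'])  # 0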
@@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding):

         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

@@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding):
         # calculate the number of tokens in the encoded text
         return len(tokenized_text)

-    def get_token_price(self, tokens: int):
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * decimal.Decimal('0.0001')
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
             logging.warning("Invalid request to OpenAI API.")

@@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding):

         super().__init__(model_provider, client, name)

-    def get_token_price(self, tokens: int):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, (ModelError, ReplicateError)):
             return LLMBadRequestError(f"Replicate: {str(ex)}")
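The flat 0.0001 USD per 1K tokens that the OpenAI and Azure embedding classes hard-coded now comes from the text-embedding-ada-002 entries added to the rules JSON later in this commit (completion "0.0001", unit "0.001"). A quick check that the config-driven formula reproduces the old per-1K calculation:

import decimal

tokens = 1536
q = decimal.Decimal('0.0000001')

# Old hard-coded path: $0.0001 per 1K tokens, rounded per 1K first.
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                          rounding=decimal.ROUND_HALF_UP)
old_price = (tokens_per_1k * decimal.Decimal('0.0001')).quantize(q, rounding=decimal.ROUND_HALF_UP)

# New config-driven path: completion="0.0001", unit="0.001" (text-embedding-ada-002 entry).
new_price = (tokens * decimal.Decimal('0.0001') * decimal.Decimal('0.001')).quantize(
    q, rounding=decimal.ROUND_HALF_UP)

print(old_price, new_price)  # 0.0001536 0.0001536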
@@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'claude-instant-1': {
-                'prompt': decimal.Decimal('1.63'),
-                'completion': decimal.Decimal('5.51'),
-            },
-            'claude-2': {
-                'prompt': decimal.Decimal('11.02'),
-                'completion': decimal.Decimal('32.68'),
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
-                                                                     rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1m * unit_price
-        return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():
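Anthropic quotes prices per million tokens, which the removed code expressed by dividing the token count by 1,000,000; the anthropic price_config added later in this commit keeps the same per-1M figures and encodes the division as unit "0.000001". A quick arithmetic check with the claude-instant-1 prompt price:

import decimal

tokens = 5000
prompt_price_per_1m = decimal.Decimal('1.63')  # claude-instant-1 prompt price
unit = decimal.Decimal('0.000001')             # per-token scale from the new anthropic rules
q = decimal.Decimal('0.0000001')

old = ((decimal.Decimal(tokens) / 1000000) * prompt_price_per_1m).quantize(q, rounding=decimal.ROUND_HALF_UP)
new = (tokens * prompt_price_per_1m * unit).quantize(q, rounding=decimal.ROUND_HALF_UP)
print(old, new)  # 0.0081500 0.0081500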
@@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM):
             self.model_mode = ModelMode.COMPLETION
         else:
             self.model_mode = ModelMode.CHAT
-
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

     def _init_client(self) -> Any:

@@ -84,6 +83,15 @@ class AzureOpenAIModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)

+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name (not deployment)
+
+        :return: str
+        """
+        return self.credentials.get("base_model_name")
+
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.

@@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-35-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-35-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        base_model_name = self.credentials.get("base_model_name")
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[base_model_name]['prompt']
-        else:
-            unit_price = model_unit_prices[base_model_name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name == 'text-davinci-003':
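On Azure, self.name is a user-chosen deployment name, so the new base_model_name property overrides the BaseLLM default (which returns self.name) and reads the underlying model from the credentials; the price lookup then keys on names like gpt-35-turbo rather than the deployment name. A minimal sketch with hypothetical credential and deployment values:

import decimal

# Hypothetical Azure-style setup: the deployment name differs from the base model name.
credentials = {"base_model_name": "gpt-35-turbo"}
deployment_name = "my-chat-deployment"

price_config = {  # same shape as the azure_openai rules added in this commit
    "gpt-35-turbo": {"prompt": decimal.Decimal("0.0015"),
                     "completion": decimal.Decimal("0.002"),
                     "unit": decimal.Decimal("0.001")},
}

base_model_name = credentials.get("base_model_name")  # what the new property returns
assert deployment_name not in price_config            # keying on the deployment would miss
assert base_model_name in price_config
print(price_config[base_model_name]["prompt"])        # 0.0015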
@@ -1,5 +1,6 @@
 from abc import abstractmethod
 from typing import List, Optional, Any, Union
+import decimal

 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

@@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.third_party.langchain.llms.fake import FakeLLM
+import logging
+logger = logging.getLogger(__name__)


 class BaseLLM(BaseProviderModel):

@@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
     def _init_client(self) -> Any:
         raise NotImplementedError

+    @property
+    def base_model_name(self) -> str:
+        """
+        get llm base model name
+
+        :return: str
+        """
+        return self.name
+
+    @property
+    def price_config(self) -> dict:
+        def get_or_default():
+            default_price_config = {
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
+            rules = self.model_provider.get_rules()
+            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = {
+                'prompt': decimal.Decimal(price_config['prompt']),
+                'completion': decimal.Decimal(price_config['completion']),
+                'unit': decimal.Decimal(price_config['unit']),
+                'currency': price_config['currency']
+            }
+            return price_config
+
+        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
+
+        logger.debug(f"model: {self.name} price_config: {self._price_config}")
+        return self._price_config
+
     def run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,

@@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
         """
         raise NotImplementedError

-    @abstractmethod
-    def get_token_price(self, tokens: int, message_type: MessageType):
+    def calc_tokens_price(self, tokens:int, message_type: MessageType):
         """
-        get token price.
+        calc tokens total price.

         :param tokens:
         :param message_type:
         :return:
         """
-        raise NotImplementedError
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit = self.price_config['unit']
+
+        total_price = tokens * unit_price * unit
+        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
+        return total_price
+
+    def get_tokens_unit_price(self, message_type: MessageType):
+        """
+        get token price.
+
+        :param message_type:
+        :return: decimal.Decimal('0.0001')
+        """
+        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
+            unit_price = self.price_config['prompt']
+        else:
+            unit_price = self.price_config['completion']
+        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
+        logging.debug(f"unit_price={unit_price}")
+        return unit_price

-    @abstractmethod
     def get_currency(self):
         """
         get token currency.

-        :return:
+        :return: get from price config, default 'USD'
         """
-        raise NotImplementedError
+        currency = self.price_config['currency']
+        return currency

     def get_model_kwargs(self):
         return self.model_kwargs
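With the per-provider price tables gone, every BaseLLM subclass shares one formula: pick the prompt or completion unit price by message type, then compute tokens * unit_price * unit. A self-contained sketch using the gpt-3.5-turbo figures from the rules JSON added in this commit (the config is inlined here instead of going through model_provider.get_rules()):

import decimal

price_config = {                       # "gpt-3.5-turbo" entry from the new rules JSON
    'prompt': decimal.Decimal('0.0015'),
    'completion': decimal.Decimal('0.002'),
    'unit': decimal.Decimal('0.001'),  # prices are per 1K tokens
    'currency': 'USD',
}

def calc_tokens_price(tokens: int, is_prompt: bool) -> decimal.Decimal:
    # Standalone version of BaseLLM.calc_tokens_price (message_type reduced to a bool).
    unit_price = price_config['prompt'] if is_prompt else price_config['completion']
    total_price = tokens * unit_price * price_config['unit']
    return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

print(calc_tokens_price(1000, True))   # 0.0015000 -> $0.0015 for 1K prompt tokens
print(calc_tokens_price(500, False))   # 0.0010000 -> $0.001 for 500 completion tokens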
@@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

@@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # not support calc price
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs

@@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

@@ -47,6 +47,7 @@ class OpenAIModel(BaseLLM):
         else:
             self.model_mode = ModelMode.CHAT

+        # TODO load price config from configs(db)
         super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

     def _init_client(self) -> Any:

@@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM):
         else:
             return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'gpt-4': {
-                'prompt': decimal.Decimal('0.03'),
-                'completion': decimal.Decimal('0.06'),
-            },
-            'gpt-4-32k': {
-                'prompt': decimal.Decimal('0.06'),
-                'completion': decimal.Decimal('0.12')
-            },
-            'gpt-3.5-turbo': {
-                'prompt': decimal.Decimal('0.0015'),
-                'completion': decimal.Decimal('0.002')
-            },
-            'gpt-3.5-turbo-16k': {
-                'prompt': decimal.Decimal('0.003'),
-                'completion': decimal.Decimal('0.004')
-            },
-            'text-davinci-003': {
-                'prompt': decimal.Decimal('0.02'),
-                'completion': decimal.Decimal('0.02')
-            },
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         if self.name in COMPLETION_MODELS:

@@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM):

         return self._client.get_num_tokens(prompts)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        # replicate only pay for prediction seconds
-        return decimal.Decimal('0')
-
-    def get_currency(self):
-        return 'USD'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.input = provider_model_kwargs

@@ -50,9 +50,6 @@ class SparkModel(BaseLLM):
         contents = [message.content for message in messages]
         return max(self._client.get_num_tokens("".join(contents)), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

@@ -53,9 +53,6 @@ class TongyiModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        return decimal.Decimal('0')
-
     def get_currency(self):
         return 'RMB'

@@ -16,6 +16,7 @@ class WenxinModel(BaseLLM):

     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        # TODO load price_config from configs(db)
         return Wenxin(
             streaming=self.streaming,
             callbacks=self.callbacks,

@@ -48,36 +49,6 @@ class WenxinModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)

-    def get_token_price(self, tokens: int, message_type: MessageType):
-        model_unit_prices = {
-            'ernie-bot': {
-                'prompt': decimal.Decimal('0.012'),
-                'completion': decimal.Decimal('0.012'),
-            },
-            'ernie-bot-turbo': {
-                'prompt': decimal.Decimal('0.008'),
-                'completion': decimal.Decimal('0.008')
-            },
-            'bloomz-7b': {
-                'prompt': decimal.Decimal('0.006'),
-                'completion': decimal.Decimal('0.006')
-            }
-        }
-
-        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
-            unit_price = model_unit_prices[self.name]['prompt']
-        else:
-            unit_price = model_unit_prices[self.name]['completion']
-
-        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                  rounding=decimal.ROUND_HALF_UP)
-
-        total_price = tokens_per_1k * unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
-
-    def get_currency(self):
-        return 'RMB'
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         for k, v in provider_model_kwargs.items():
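The providers whose removed get_token_price simply returned decimal.Decimal('0') (ChatGLM, Huggingface Hub, Minimax, Replicate, Spark, Tongyi) presumably keep no price_config entry — only the anthropic, azure_openai, openai and wenxin rules are extended in this commit — so they fall back to the all-zero default and the computed price stays zero, matching the old stubs. A minimal check under that assumption:

import decimal

default_price_config = {  # fallback used when a provider's rules lack "price_config"
    'prompt': decimal.Decimal('0'),
    'completion': decimal.Decimal('0'),
    'unit': decimal.Decimal('0'),
    'currency': 'USD',
}

tokens = 12345
total_price = (tokens * default_price_config['prompt'] * default_price_config['unit']).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total_price)  # 0E-7, i.e. zero -- same result as the removed Decimal('0') stubs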
@@ -11,5 +11,19 @@
     "quota_unit": "tokens",
     "quota_limit": 600000
   },
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "claude-instant-1": {
+      "prompt": "1.63",
+      "completion": "5.51",
+      "unit": "0.000001",
+      "currency": "USD"
+    },
+    "claude-2": {
+      "prompt": "11.02",
+      "completion": "32.68",
+      "unit": "0.000001",
+      "currency": "USD"
+    }
+  }
 }

@@ -3,5 +3,48 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "configurable"
+  "model_flexibility": "configurable",
+  "price_config":{
+    "gpt-4": {
+      "prompt": "0.03",
+      "completion": "0.06",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-4-32k": {
+      "prompt": "0.06",
+      "completion": "0.12",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-35-turbo": {
+      "prompt": "0.0015",
+      "completion": "0.002",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-35-turbo-16k": {
+      "prompt": "0.003",
+      "completion": "0.004",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-002": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-003": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-embedding-ada-002":{
+      "completion": "0.0001",
+      "unit": "0.001",
+      "currency": "USD"
+    }
+  }
 }

@@ -10,5 +10,42 @@
     "quota_unit": "times",
     "quota_limit": 200
   },
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "gpt-4": {
+      "prompt": "0.03",
+      "completion": "0.06",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-4-32k": {
+      "prompt": "0.06",
+      "completion": "0.12",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-3.5-turbo": {
+      "prompt": "0.0015",
+      "completion": "0.002",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "gpt-3.5-turbo-16k": {
+      "prompt": "0.003",
+      "completion": "0.004",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-davinci-003": {
+      "prompt": "0.02",
+      "completion": "0.02",
+      "unit": "0.001",
+      "currency": "USD"
+    },
+    "text-embedding-ada-002":{
+      "completion": "0.0001",
+      "unit": "0.001",
+      "currency": "USD"
+    }
+  }
 }

@@ -3,5 +3,25 @@
     "custom"
   ],
   "system_config": null,
-  "model_flexibility": "fixed"
+  "model_flexibility": "fixed",
+  "price_config": {
+    "ernie-bot": {
+      "prompt": "0.012",
+      "completion": "0.012",
+      "unit": "0.001",
+      "currency": "RMB"
+    },
+    "ernie-bot-turbo": {
+      "prompt": "0.008",
+      "completion": "0.008",
+      "unit": "0.001",
+      "currency": "RMB"
+    },
+    "bloomz-7b": {
+      "prompt": "0.006",
+      "completion": "0.006",
+      "unit": "0.001",
+      "currency": "RMB"
+    }
+  }
 }
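In these rules files, prompt and completion are prices per pricing unit and unit is the multiplier that converts a raw token count into that unit: "0.001" means per-1K-token prices (OpenAI, Azure OpenAI, Wenxin), while Anthropic's "0.000001" keeps its per-million-token prices; currency is what get_currency() now returns. A small sketch computing a cost directly from one of the entries above:

import decimal, json

entry = json.loads('''{
  "prompt": "11.02",
  "completion": "32.68",
  "unit": "0.000001",
  "currency": "USD"
}''')  # the "claude-2" entry added above

def cost(tokens: int, kind: str) -> decimal.Decimal:
    # kind is "prompt" or "completion"
    unit_price = decimal.Decimal(entry[kind])
    unit = decimal.Decimal(entry["unit"])
    return (tokens * unit_price * unit).quantize(decimal.Decimal('0.0000001'),
                                                 rounding=decimal.ROUND_HALF_UP)

print(cost(1_000_000, "prompt"), entry["currency"])  # 11.0200000 USD ($11.02 per 1M prompt tokens)
print(cost(2_000, "completion"), entry["currency"])  # 0.0653600 USD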