mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 16:28:58 +08:00
feat: record price unit in messages (#919)
This commit is contained in:
parent
920fb6d0e1
commit
0a0d63457d
@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
|
||||
|
||||
from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.conversation_message_task import ConversationMessageTask
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
self._current_loop.status = 'llm_end'
|
||||
if response.llm_output:
|
||||
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
else:
|
||||
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.prompt)]
|
||||
)
|
||||
completion_generation = response.generations[0][0]
|
||||
if isinstance(completion_generation, ChatGeneration):
|
||||
completion_message = completion_generation.message
|
||||
@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
if response.llm_output:
|
||||
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
|
||||
else:
|
||||
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
|
||||
[PromptMessage(content=self._current_loop.completion)]
|
||||
)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
|
@ -119,9 +119,11 @@ class ConversationMessageTask:
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency=self.model_instance.get_currency(),
|
||||
@ -142,7 +144,9 @@ class ConversationMessageTask:
|
||||
answer_tokens = llm_message.completion_tokens
|
||||
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
|
||||
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
|
||||
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
|
||||
@ -151,9 +155,11 @@ class ConversationMessageTask:
|
||||
self.message.message = llm_message.prompt
|
||||
self.message.message_tokens = message_tokens
|
||||
self.message.message_unit_price = message_unit_price
|
||||
self.message.message_price_unit = message_price_unit
|
||||
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
self.message.answer_price_unit = answer_price_unit
|
||||
self.message.provider_response_latency = llm_message.latency
|
||||
self.message.total_price = total_price
|
||||
|
||||
@ -195,7 +201,9 @@ class ConversationMessageTask:
|
||||
tool=agent_loop.tool_name,
|
||||
tool_input=agent_loop.tool_input,
|
||||
message=agent_loop.prompt,
|
||||
message_price_unit=0,
|
||||
answer=agent_loop.completion,
|
||||
answer_price_unit=0,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
@ -210,7 +218,9 @@ class ConversationMessageTask:
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
|
||||
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
@ -223,8 +233,10 @@ class ConversationMessageTask:
|
||||
message_agent_thought.tool_process_data = '' # currently not support
|
||||
message_agent_thought.message_token = loop_message_tokens
|
||||
message_agent_thought.message_unit_price = agent_message_unit_price
|
||||
message_agent_thought.message_price_unit = agent_message_price_unit
|
||||
message_agent_thought.answer_token = loop_answer_tokens
|
||||
message_agent_thought.answer_unit_price = agent_answer_unit_price
|
||||
message_agent_thought.answer_price_unit = agent_answer_price_unit
|
||||
message_agent_thought.latency = agent_loop.latency
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
|
@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def calc_tokens_price(self, tokens:int, message_type: MessageType):
|
||||
def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
calc tokens total price.
|
||||
|
||||
@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel):
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
unit = self.price_config['unit']
|
||||
unit = self.get_price_unit(message_type)
|
||||
|
||||
total_price = tokens * unit_price * unit
|
||||
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
|
||||
return total_price
|
||||
|
||||
def get_tokens_unit_price(self, message_type: MessageType):
|
||||
def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get token price.
|
||||
|
||||
@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel):
|
||||
logging.debug(f"unit_price={unit_price}")
|
||||
return unit_price
|
||||
|
||||
def get_currency(self):
|
||||
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
|
||||
"""
|
||||
get price unit.
|
||||
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.000001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
price_unit = self.price_config['unit']
|
||||
else:
|
||||
price_unit = self.price_config['unit']
|
||||
|
||||
price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
|
||||
logging.debug(f"price_unit={price_unit}")
|
||||
return price_unit
|
||||
|
||||
def get_currency(self) -> str:
|
||||
"""
|
||||
get token currency.
|
||||
|
||||
|
@ -0,0 +1,43 @@
|
||||
"""add message price unit
|
||||
|
||||
Revision ID: 853f9b9cd3b6
|
||||
Revises: e8883b0148c9
|
||||
Create Date: 2023-08-19 17:01:57.471562
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '853f9b9cd3b6'
|
||||
down_revision = 'e8883b0148c9'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
|
||||
batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
|
||||
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
|
||||
batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
batch_op.drop_column('answer_price_unit')
|
||||
batch_op.drop_column('message_price_unit')
|
||||
|
||||
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||
batch_op.drop_column('answer_price_unit')
|
||||
batch_op.drop_column('message_price_unit')
|
||||
|
||||
# ### end Alembic commands ###
|
@ -421,9 +421,11 @@ class Message(db.Model):
|
||||
message = db.Column(db.JSON, nullable=False)
|
||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
answer = db.Column(db.Text, nullable=False)
|
||||
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
||||
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0'))
|
||||
total_price = db.Column(db.Numeric(10, 7))
|
||||
currency = db.Column(db.String(255), nullable=False)
|
||||
@ -705,9 +707,11 @@ class MessageAgentThought(db.Model):
|
||||
message = db.Column(db.Text, nullable=True)
|
||||
message_token = db.Column(db.Integer, nullable=True)
|
||||
message_unit_price = db.Column(db.Numeric, nullable=True)
|
||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
answer = db.Column(db.Text, nullable=True)
|
||||
answer_token = db.Column(db.Integer, nullable=True)
|
||||
answer_unit_price = db.Column(db.Numeric, nullable=True)
|
||||
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
|
||||
tokens = db.Column(db.Integer, nullable=True)
|
||||
total_price = db.Column(db.Numeric, nullable=True)
|
||||
currency = db.Column(db.String, nullable=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user