From 0a0d63457da7538d847e2f93a64371507233f188 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 19 Aug 2023 18:51:40 +0800 Subject: [PATCH] feat: record price unit in messages (#919) --- .../agent_loop_gather_callback_handler.py | 9 ++++ api/core/conversation_message_task.py | 12 ++++++ api/core/model_providers/models/llm/base.py | 24 +++++++++-- .../853f9b9cd3b6_add_message_price_unit.py | 43 +++++++++++++++++++ api/models/model.py | 4 ++ 5 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index 72c6018323..c8cc043478 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration from core.callback_handler.entity.agent_loop import AgentLoop from core.conversation_message_task import ConversationMessageTask +from core.model_providers.models.entity.message import PromptMessage from core.model_providers.models.llm.base import BaseLLM @@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): self._current_loop.status = 'llm_end' if response.llm_output: self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] + else: + self._current_loop.prompt_tokens = self.model_instant.get_num_tokens( + [PromptMessage(content=self._current_loop.prompt)] + ) completion_generation = response.generations[0][0] if isinstance(completion_generation, ChatGeneration): completion_message = completion_generation.message @@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): if response.llm_output: self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] + else: + self._current_loop.completion_tokens = self.model_instant.get_num_tokens( + [PromptMessage(content=self._current_loop.completion)] + ) def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index df06101e4d..c5ba8e6a7c 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -119,9 +119,11 @@ class ConversationMessageTask: message="", message_tokens=0, message_unit_price=0, + message_price_unit=0, answer="", answer_tokens=0, answer_unit_price=0, + answer_price_unit=0, provider_response_latency=0, total_price=0, currency=self.model_instance.get_currency(), @@ -142,7 +144,9 @@ class ConversationMessageTask: answer_tokens = llm_message.completion_tokens message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) + message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN) answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) + answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT) message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) @@ -151,9 +155,11 @@ class ConversationMessageTask: self.message.message = llm_message.prompt self.message.message_tokens = message_tokens self.message.message_unit_price = message_unit_price + self.message.message_price_unit = message_price_unit self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else '' self.message.answer_tokens = answer_tokens self.message.answer_unit_price = answer_unit_price + self.message.answer_price_unit = answer_price_unit self.message.provider_response_latency = llm_message.latency self.message.total_price = total_price @@ -195,7 +201,9 @@ class ConversationMessageTask: tool=agent_loop.tool_name, tool_input=agent_loop.tool_input, message=agent_loop.prompt, + message_price_unit=0, answer=agent_loop.completion, + answer_price_unit=0, created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), created_by=self.user.id ) @@ -210,7 +218,9 @@ class ConversationMessageTask: def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, agent_loop: AgentLoop): agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN) + agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN) agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT) + agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT) loop_message_tokens = agent_loop.prompt_tokens loop_answer_tokens = agent_loop.completion_tokens @@ -223,8 +233,10 @@ class ConversationMessageTask: message_agent_thought.tool_process_data = '' # currently not support message_agent_thought.message_token = loop_message_tokens message_agent_thought.message_unit_price = agent_message_unit_price + message_agent_thought.message_price_unit = agent_message_price_unit message_agent_thought.answer_token = loop_answer_tokens message_agent_thought.answer_unit_price = agent_answer_unit_price + message_agent_thought.answer_price_unit = agent_answer_price_unit message_agent_thought.latency = agent_loop.latency message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens message_agent_thought.total_price = loop_total_price diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py index 6b20098be3..fe92407069 100644 --- a/api/core/model_providers/models/llm/base.py +++ b/api/core/model_providers/models/llm/base.py @@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel): """ raise NotImplementedError - def calc_tokens_price(self, tokens:int, message_type: MessageType): + def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal: """ calc tokens total price. @@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel): unit_price = self.price_config['prompt'] else: unit_price = self.price_config['completion'] - unit = self.price_config['unit'] + unit = self.get_price_unit(message_type) total_price = tokens * unit_price * unit total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") return total_price - def get_tokens_unit_price(self, message_type: MessageType): + def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal: """ get token price. @@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel): logging.debug(f"unit_price={unit_price}") return unit_price - def get_currency(self): + def get_price_unit(self, message_type: MessageType) -> decimal.Decimal: + """ + get price unit. + + :param message_type: + :return: decimal.Decimal('0.000001') + """ + if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + price_unit = self.price_config['unit'] + else: + price_unit = self.price_config['unit'] + + price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP) + logging.debug(f"price_unit={price_unit}") + return price_unit + + def get_currency(self) -> str: """ get token currency. diff --git a/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py new file mode 100644 index 0000000000..f3c13095a6 --- /dev/null +++ b/api/migrations/versions/853f9b9cd3b6_add_message_price_unit.py @@ -0,0 +1,43 @@ +"""add message price unit + +Revision ID: 853f9b9cd3b6 +Revises: e8883b0148c9 +Create Date: 2023-08-19 17:01:57.471562 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '853f9b9cd3b6' +down_revision = 'e8883b0148c9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False)) + batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False)) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False)) + batch_op.add_column(sa.Column('answer_price_unit', sa.Numeric(precision=10, scale=7), server_default=sa.text('0.001'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_column('answer_price_unit') + batch_op.drop_column('message_price_unit') + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_column('answer_price_unit') + batch_op.drop_column('message_price_unit') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index d11cff2e24..b77363d0b8 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -421,9 +421,11 @@ class Message(db.Model): message = db.Column(db.JSON, nullable=False) message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) answer = db.Column(db.Text, nullable=False) answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) @@ -705,9 +707,11 @@ class MessageAgentThought(db.Model): message = db.Column(db.Text, nullable=True) message_token = db.Column(db.Integer, nullable=True) message_unit_price = db.Column(db.Numeric, nullable=True) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) answer = db.Column(db.Text, nullable=True) answer_token = db.Column(db.Integer, nullable=True) answer_unit_price = db.Column(db.Numeric, nullable=True) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) tokens = db.Column(db.Integer, nullable=True) total_price = db.Column(db.Numeric, nullable=True) currency = db.Column(db.String, nullable=True)