From 91ee62d1abdaa196ee0c0af23fb3d8d311b9efd4 Mon Sep 17 00:00:00 2001
From: Garfield Dai
Date: Wed, 3 Jan 2024 18:29:44 +0800
Subject: [PATCH] fix: huggingface and replicate. (#1888)

---
 .../huggingface_hub/llm/llm.py              | 35 ++++++++++-----
 .../model_providers/replicate/llm/llm.py    | 44 +++++++++++++------
 2 files changed, 54 insertions(+), 25 deletions(-)

diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
index 366950ad84..33df2ec340 100644
--- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
+++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
@@ -154,20 +154,31 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
                     content=chunk.token.text
                 )
 
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+            if chunk.details:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage,
+                        finish_reason=chunk.details.finish_reason,
+                    ),
+                )
+            else:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    ),
+                )
 
     def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
         if isinstance(response, str):
diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py
index 556fab977d..cd750375be 100644
--- a/api/core/model_runtime/model_providers/replicate/llm/llm.py
+++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py
@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         )
 
         for key, value in input_properties:
-            if key not in ['system_prompt', 'prompt']:
+            if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
                 value_type = value.get('type')
 
                 if not value_type:
@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         index = -1
         current_completion: str = ""
         stop_condition_reached = False
+
+        prediction_output_length = 10000
+        is_prediction_output_finished = False
+
         for output in prediction.output_iterator():
             current_completion += output
 
+            if not is_prediction_output_finished and prediction.status == 'succeeded':
+                prediction_output_length = len(prediction.output) - 1
+                is_prediction_output_finished = True
+
             if stop:
                 for s in stop:
                     if s in current_completion:
@@ -172,20 +180,30 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
                 content=output if output else ''
             )
 
-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+            if index < prediction_output_length:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message
+                    )
+                )
+            else:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage
+                    )
+                )
 
     def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
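Both files now follow the same streaming pattern: intermediate chunks carry only the text delta, and token usage (plus a finish reason where the provider reports one) is computed once and attached to the final chunk. HuggingFace marks the final chunk via `chunk.details`; for Replicate the last chunk is detected by watching for `prediction.status == 'succeeded'` and comparing the chunk index against the finished output length. The sketch below illustrates that pattern in isolation; it is not the Dify model-runtime code, and `StreamChunk`, `count_tokens` and `stream_chunks` are made-up names used only for illustration.

```python
# Minimal sketch of the "usage only on the final chunk" streaming pattern the
# patch adopts. StreamChunk, count_tokens and stream_chunks are illustrative
# stand-ins, not Dify runtime classes or functions.
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class StreamChunk:
    index: int
    text: str
    usage: Optional[dict] = None           # only populated on the last chunk
    finish_reason: Optional[str] = None    # e.g. "stop"


def count_tokens(text: str) -> int:
    # Placeholder tokenizer; the real handlers call self.get_num_tokens(...).
    return len(text.split())


def stream_chunks(prompt: str, pieces: list[str]) -> Iterator[StreamChunk]:
    completion = ""
    for index, piece in enumerate(pieces):
        completion += piece
        if index < len(pieces) - 1:
            # Intermediate chunk: text only, no token accounting.
            yield StreamChunk(index=index, text=piece)
        else:
            # Final chunk: pay the token-counting cost once and attach usage,
            # mirroring the `if chunk.details:` / `index < prediction_output_length`
            # branches in the patch.
            usage = {
                "prompt_tokens": count_tokens(prompt),
                "completion_tokens": count_tokens(completion),
            }
            yield StreamChunk(index=index, text=piece, usage=usage, finish_reason="stop")


if __name__ == "__main__":
    for chunk in stream_chunks("Say hello", ["Hel", "lo", "!"]):
        print(chunk)
```

Compared with the previous code, which ran `get_num_tokens` and `_calc_response_usage` on every streamed token, the accounting now happens once per response. The first Replicate hunk separately excludes any `stop`-related key from the auto-generated parameter rules; stop sequences are already checked against `current_completion` inside the streaming loop.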