From 9954ddb780622e7b1cbd2f8baf44b1df2754d0b3 Mon Sep 17 00:00:00 2001
From: Warren Chen
Date: Thu, 2 Jan 2025 09:49:11 +0800
Subject: [PATCH] [Fix] modify sagemaker llm (#12274)

---
 .../model_providers/sagemaker/llm/llm.py | 150 ++++++++++++------
 1 file changed, 103 insertions(+), 47 deletions(-)

diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
index b8c979b1f5..d01e24e2ac 100644
--- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
+++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
@@ -1,6 +1,5 @@
 import json
 import logging
-import re
 from collections.abc import Generator, Iterator
 from typing import Any, Optional, Union, cast
 
@@ -132,58 +131,115 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         """
         handle stream chat generate response
         """
+
+        class ChunkProcessor:
+            def __init__(self):
+                self.buffer = bytearray()
+
+            def try_decode_chunk(self, chunk: bytes) -> list[dict]:
+                """Try to decode complete JSON objects from the chunk"""
+                self.buffer.extend(chunk)
+                results = []
+
+                while True:
+                    try:
+                        start = self.buffer.find(b"{")
+                        if start == -1:
+                            self.buffer.clear()
+                            break
+
+                        bracket_count = 0
+                        end = start
+
+                        for i in range(start, len(self.buffer)):
+                            if self.buffer[i] == ord("{"):
+                                bracket_count += 1
+                            elif self.buffer[i] == ord("}"):
+                                bracket_count -= 1
+                                if bracket_count == 0:
+                                    end = i + 1
+                                    break
+
+                        if bracket_count != 0:
+                            # Incomplete JSON, wait for more data
+                            if start > 0:
+                                self.buffer = self.buffer[start:]
+                            break
+
+                        json_bytes = self.buffer[start:end]
+                        try:
+                            data = json.loads(json_bytes)
+                            results.append(data)
+                            self.buffer = self.buffer[end:]
+                        except json.JSONDecodeError:
+                            self.buffer = self.buffer[start + 1 :]
+
+                    except Exception as e:
+                        logger.debug(f"Warning: Error processing chunk ({str(e)})")
+                        if start > 0:
+                            self.buffer = self.buffer[start:]
+                        break
+
+                return results
+
         full_response = ""
-        buffer = ""
-        for chunk_bytes in resp:
-            buffer += chunk_bytes.decode("utf-8")
-            last_idx = 0
-            for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
-                try:
-                    data = json.loads(match.group(1).strip())
-                    last_idx = match.span()[1]
+        processor = ChunkProcessor()

-                    if "content" in data["choices"][0]["delta"]:
-                        chunk_content = data["choices"][0]["delta"]["content"]
-                        assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
+        try:
+            for chunk in resp:
+                json_objects = processor.try_decode_chunk(chunk)

-                        if data["choices"][0]["finish_reason"] is not None:
-                            temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
-                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
-                            completion_tokens = self._num_tokens_from_messages(
-                                messages=[temp_assistant_prompt_message], tools=[]
-                            )
-                            usage = self._calc_response_usage(
-                                model=model,
-                                credentials=credentials,
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                            )
+                for data in json_objects:
+                    if data.get("choices"):
+                        choice = data["choices"][0]

-                            yield LLMResultChunk(
-                                model=model,
-                                prompt_messages=prompt_messages,
-                                system_fingerprint=None,
-                                delta=LLMResultChunkDelta(
-                                    index=0,
-                                    message=assistant_prompt_message,
-                                    finish_reason=data["choices"][0]["finish_reason"],
-                                    usage=usage,
-                                ),
-                            )
-                        else:
-                            yield LLMResultChunk(
-                                model=model,
-                                prompt_messages=prompt_messages,
-                                system_fingerprint=None,
-                                delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
-                            )
+                        if "delta" in choice and "content" in choice["delta"]:
+                            chunk_content = choice["delta"]["content"]
+                            assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])

-                        full_response += chunk_content
-                except (json.JSONDecodeError, KeyError, IndexError) as e:
-                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
-                    pass
+                            if choice.get("finish_reason") is not None:
+                                temp_assistant_prompt_message = AssistantPromptMessage(
+                                    content=full_response, tool_calls=[]
+                                )

-            buffer = buffer[last_idx:]
+                                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                                completion_tokens = self._num_tokens_from_messages(
+                                    messages=[temp_assistant_prompt_message], tools=[]
+                                )
+
+                                usage = self._calc_response_usage(
+                                    model=model,
+                                    credentials=credentials,
+                                    prompt_tokens=prompt_tokens,
+                                    completion_tokens=completion_tokens,
+                                )
+
+                                yield LLMResultChunk(
+                                    model=model,
+                                    prompt_messages=prompt_messages,
+                                    system_fingerprint=None,
+                                    delta=LLMResultChunkDelta(
+                                        index=0,
+                                        message=assistant_prompt_message,
+                                        finish_reason=choice["finish_reason"],
+                                        usage=usage,
+                                    ),
+                                )
+                            else:
+                                yield LLMResultChunk(
+                                    model=model,
+                                    prompt_messages=prompt_messages,
+                                    system_fingerprint=None,
+                                    delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
+                                )
+
+                            full_response += chunk_content
+
+        except Exception as e:
+            raise
+
+        if not full_response:
+            logger.warning("No content received from stream response")

     def _invoke(
         self,
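
Note (not part of the patch): the sketch below reproduces the brace-matching decoder introduced above in a simplified standalone form, without the logging and outer exception handling, to illustrate why the regex-based "data: ...\n\n" parsing was dropped. The new decoder buffers raw bytes and only emits a JSON object once its braces balance, so a payload split across several SageMaker stream chunks is reassembled before json.loads is called. The sample chunks are invented for illustration.

import json


class ChunkProcessor:
    """Simplified standalone copy of the decoder added by this patch (illustration only)."""

    def __init__(self):
        self.buffer = bytearray()

    def try_decode_chunk(self, chunk: bytes) -> list[dict]:
        self.buffer.extend(chunk)
        results = []
        while True:
            start = self.buffer.find(b"{")
            if start == -1:
                self.buffer.clear()
                break
            bracket_count = 0
            end = start
            for i in range(start, len(self.buffer)):
                if self.buffer[i] == ord("{"):
                    bracket_count += 1
                elif self.buffer[i] == ord("}"):
                    bracket_count -= 1
                    if bracket_count == 0:
                        end = i + 1
                        break
            if bracket_count != 0:
                # Incomplete object: keep the partial bytes and wait for the next chunk.
                if start > 0:
                    self.buffer = self.buffer[start:]
                break
            try:
                results.append(json.loads(self.buffer[start:end]))
                self.buffer = self.buffer[end:]
            except json.JSONDecodeError:
                self.buffer = self.buffer[start + 1 :]
        return results


# One streamed payload arriving split across three network chunks.
chunks = [
    b'data: {"choices": [{"delta": {"con',
    b'tent": "Hello"}, "finish_re',
    b'ason": null}]}\n\n',
]
processor = ChunkProcessor()
for chunk in chunks:
    for obj in processor.try_decode_chunk(chunk):
        print(obj["choices"][0]["delta"]["content"])  # prints "Hello" once the object completes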