diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
index d01e24e2ac..b8c979b1f5 100644
--- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
+++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator, Iterator
 from typing import Any, Optional, Union, cast
 
@@ -131,115 +132,58 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         """
         handle stream chat generate response
         """
-
-        class ChunkProcessor:
-            def __init__(self):
-                self.buffer = bytearray()
-
-            def try_decode_chunk(self, chunk: bytes) -> list[dict]:
-                """Try to decode complete JSON objects from the chunk"""
-                self.buffer.extend(chunk)
-                results = []
-
-                while True:
-                    try:
-                        start = self.buffer.find(b"{")
-                        if start == -1:
-                            self.buffer.clear()
-                            break
-
-                        bracket_count = 0
-                        end = start
-
-                        for i in range(start, len(self.buffer)):
-                            if self.buffer[i] == ord("{"):
-                                bracket_count += 1
-                            elif self.buffer[i] == ord("}"):
-                                bracket_count -= 1
-                                if bracket_count == 0:
-                                    end = i + 1
-                                    break
-
-                        if bracket_count != 0:
-                            # JSON is incomplete; wait for more data
-                            if start > 0:
-                                self.buffer = self.buffer[start:]
-                            break
-
-                        json_bytes = self.buffer[start:end]
-                        try:
-                            data = json.loads(json_bytes)
-                            results.append(data)
-                            self.buffer = self.buffer[end:]
-                        except json.JSONDecodeError:
-                            self.buffer = self.buffer[start + 1 :]
-
-                    except Exception as e:
-                        logger.debug(f"Warning: Error processing chunk ({str(e)})")
-                        if start > 0:
-                            self.buffer = self.buffer[start:]
-                        break
-
-                return results
-
         full_response = ""
-        processor = ChunkProcessor()
+        buffer = ""
+        for chunk_bytes in resp:
+            buffer += chunk_bytes.decode("utf-8")
+            last_idx = 0
+            for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
+                try:
+                    data = json.loads(match.group(1).strip())
+                    last_idx = match.span()[1]
-        try:
-            for chunk in resp:
-                json_objects = processor.try_decode_chunk(chunk)
+                    if "content" in data["choices"][0]["delta"]:
+                        chunk_content = data["choices"][0]["delta"]["content"]
+                        assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
-                for data in json_objects:
-                    if data.get("choices"):
-                        choice = data["choices"][0]
+                        if data["choices"][0]["finish_reason"] is not None:
+                            temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[])
+                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                            completion_tokens = self._num_tokens_from_messages(
+                                messages=[temp_assistant_prompt_message], tools=[]
+                            )
+                            usage = self._calc_response_usage(
+                                model=model,
+                                credentials=credentials,
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                            )
-                        if "delta" in choice and "content" in choice["delta"]:
-                            chunk_content = choice["delta"]["content"]
-                            assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message,
+                                    finish_reason=data["choices"][0]["finish_reason"],
+                                    usage=usage,
+                                ),
+                            )
+                        else:
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
+                            )
-                            if choice.get("finish_reason") is not None:
-                                temp_assistant_prompt_message = AssistantPromptMessage(
-                                    content=full_response, tool_calls=[]
-                                )
+                        full_response += chunk_content
+                except (json.JSONDecodeError, KeyError, IndexError) as e:
+                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
+                    pass
-                                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
-                                completion_tokens = self._num_tokens_from_messages(
-                                    messages=[temp_assistant_prompt_message], tools=[]
-                                )
-
-                                usage = self._calc_response_usage(
-                                    model=model,
-                                    credentials=credentials,
-                                    prompt_tokens=prompt_tokens,
-                                    completion_tokens=completion_tokens,
-                                )
-
-                                yield LLMResultChunk(
-                                    model=model,
-                                    prompt_messages=prompt_messages,
-                                    system_fingerprint=None,
-                                    delta=LLMResultChunkDelta(
-                                        index=0,
-                                        message=assistant_prompt_message,
-                                        finish_reason=choice["finish_reason"],
-                                        usage=usage,
-                                    ),
-                                )
-                            else:
-                                yield LLMResultChunk(
-                                    model=model,
-                                    prompt_messages=prompt_messages,
-                                    system_fingerprint=None,
-                                    delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
-                                )
-
-                            full_response += chunk_content
-
-        except Exception as e:
-            raise
-
-        if not full_response:
-            logger.warning("No content received from stream response")
+            buffer = buffer[last_idx:]
 
     def _invoke(
         self,
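Note (not part of the patch): the sketch below isolates the buffering strategy the added handler relies on, i.e. accumulate decoded chunks in a text buffer, pull complete `data: {...}\n\n` SSE events off its front, and keep any partial tail for the next chunk. The helper name `iter_sse_json_events`, the sample payloads, and the `re.MULTILINE` flag are illustrative assumptions; the patch itself anchors the pattern at the start of the buffer and re-scans after trimming.

```python
import json
import re


def iter_sse_json_events(byte_chunks):
    """Sketch of the SSE buffering technique: yield parsed JSON payloads."""
    buffer = ""
    pattern = re.compile(r"^data:\s*(.+?)\n\n", re.MULTILINE)
    for chunk in byte_chunks:
        buffer += chunk.decode("utf-8")
        last_idx = 0
        for match in pattern.finditer(buffer):
            try:
                yield json.loads(match.group(1).strip())
                last_idx = match.end()
            except json.JSONDecodeError:
                # Malformed or truncated payload: stop and wait for more data.
                break
        # Keep only the unparsed tail; a partial event stays buffered until
        # a later chunk completes it.
        buffer = buffer[last_idx:]


if __name__ == "__main__":
    chunks = [
        b'data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}\n\n',
        b'data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}\n\n',
    ]
    for event in iter_sse_json_events(chunks):
        print(event["choices"][0]["delta"].get("content"))
```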