[Fix] modify sagemaker llm (#12274)

This commit is contained in:
Warren Chen 2025-01-02 09:49:11 +08:00 committed by GitHub
parent b218df6920
commit 9954ddb780
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194


@@ -1,6 +1,5 @@
 import json
 import logging
-import re
 from collections.abc import Generator, Iterator
 from typing import Any, Optional, Union, cast
@@ -132,58 +131,115 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         """
         handle stream chat generate response
         """
+
+        class ChunkProcessor:
+            def __init__(self):
+                self.buffer = bytearray()
+
+            def try_decode_chunk(self, chunk: bytes) -> list[dict]:
+                """Try to decode complete JSON objects from the buffered chunk bytes."""
+                self.buffer.extend(chunk)
+                results = []
+
+                while True:
+                    try:
+                        start = self.buffer.find(b"{")
+                        if start == -1:
+                            self.buffer.clear()
+                            break
+
+                        bracket_count = 0
+                        end = start
+
+                        for i in range(start, len(self.buffer)):
+                            if self.buffer[i] == ord("{"):
+                                bracket_count += 1
+                            elif self.buffer[i] == ord("}"):
+                                bracket_count -= 1
+                                if bracket_count == 0:
+                                    end = i + 1
+                                    break
+
+                        if bracket_count != 0:
+                            # JSON is incomplete; wait for more data
+                            if start > 0:
+                                self.buffer = self.buffer[start:]
+                            break
+
+                        json_bytes = self.buffer[start:end]
+                        try:
+                            data = json.loads(json_bytes)
+                            results.append(data)
+                            self.buffer = self.buffer[end:]
+                        except json.JSONDecodeError:
+                            self.buffer = self.buffer[start + 1 :]
+
+                    except Exception as e:
+                        logger.debug(f"Warning: Error processing chunk ({str(e)})")
+                        if start > 0:
+                            self.buffer = self.buffer[start:]
+                        break
+
+                return results
full_response = "" full_response = ""
buffer = "" processor = ChunkProcessor()
for chunk_bytes in resp:
buffer += chunk_bytes.decode("utf-8")
last_idx = 0
for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
try:
data = json.loads(match.group(1).strip())
last_idx = match.span()[1]
if "content" in data["choices"][0]["delta"]: try:
chunk_content = data["choices"][0]["delta"]["content"] for chunk in resp:
assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[]) json_objects = processor.try_decode_chunk(chunk)
if data["choices"][0]["finish_reason"] is not None: for data in json_objects:
temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) if data.get("choices"):
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) choice = data["choices"][0]
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[]
)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
yield LLMResultChunk( if "delta" in choice and "content" in choice["delta"]:
model=model, chunk_content = choice["delta"]["content"]
prompt_messages=prompt_messages, assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=data["choices"][0]["finish_reason"],
usage=usage,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
)
full_response += chunk_content if choice.get("finish_reason") is not None:
except (json.JSONDecodeError, KeyError, IndexError) as e: temp_assistant_prompt_message = AssistantPromptMessage(
logger.info("json parse exception, content: {}".format(match.group(1).strip())) content=full_response, tool_calls=[]
pass )
buffer = buffer[last_idx:] prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[]
)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(
index=0,
message=assistant_prompt_message,
finish_reason=choice["finish_reason"],
usage=usage,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=None,
delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
)
full_response += chunk_content
except Exception as e:
raise
if not full_response:
logger.warning("No content received from stream response")
def _invoke( def _invoke(
self, self,
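
The core of this change is the brace-counting decoder: instead of regex-matching SSE "data:" lines out of a UTF-8 string buffer, the new code accumulates raw bytes and emits each top-level JSON object as soon as its braces balance, so payloads split across stream chunks are reassembled before parsing. Below is a minimal standalone sketch of that idea outside the Dify class; the name SimpleChunkDecoder and the sample payloads are illustrative only, not part of the commit, and (like the committed code) it does not handle unescaped braces inside JSON string values.

import json


class SimpleChunkDecoder:
    """Illustrative re-implementation of the brace-counting chunk decoder."""

    def __init__(self) -> None:
        self.buffer = bytearray()

    def feed(self, chunk: bytes) -> list[dict]:
        """Accumulate bytes and return every complete top-level JSON object found so far."""
        self.buffer.extend(chunk)
        results = []
        while True:
            start = self.buffer.find(b"{")
            if start == -1:
                self.buffer.clear()
                break
            depth = 0
            end = -1
            for i in range(start, len(self.buffer)):
                if self.buffer[i] == ord("{"):
                    depth += 1
                elif self.buffer[i] == ord("}"):
                    depth -= 1
                    if depth == 0:
                        end = i + 1
                        break
            if end == -1:
                # Incomplete object: keep the tail and wait for the next chunk.
                self.buffer = self.buffer[start:]
                break
            try:
                results.append(json.loads(self.buffer[start:end]))
            except json.JSONDecodeError:
                # Not valid JSON after all: skip this "{" and rescan.
                self.buffer = self.buffer[start + 1 :]
                continue
            self.buffer = self.buffer[end:]
        return results


decoder = SimpleChunkDecoder()
# One JSON payload split across two stream chunks, as a streaming endpoint may deliver it.
first = decoder.feed(b'{"choices": [{"delta": {"content": "Hel')
second = decoder.feed(b'lo"}, "finish_reason": null}]}')
print(first)   # [] -- nothing complete yet
print(second)  # [{'choices': [{'delta': {'content': 'Hello'}, 'finish_reason': None}]}]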