diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 58b5407c02..c4af958a3c 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -416,7 +416,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 if chunk.startswith(':'):
                     continue
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
-                chunk_json = None
+
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
@@ -620,7 +620,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return message_dict
 
-    def _num_tokens_from_string(self, model: str, text: str,
+    def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]],
                                 tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
         Approximate num tokens for model with gpt2 tokenizer.
@@ -630,7 +630,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return: number of tokens
         """
-        num_tokens = self._get_num_tokens_by_gpt2(text)
+        if isinstance(text, str):
+            full_text = text
+        else:
+            full_text = ''
+            for message_content in text:
+                if message_content.type == PromptMessageContentType.TEXT:
+                    message_content = cast(PromptMessageContent, message_content)
+                    full_text += message_content.data
+
+        num_tokens = self._get_num_tokens_by_gpt2(full_text)
 
         if tools:
             num_tokens += self._num_tokens_for_tools(tools)