fix: tongyi json output (#5396)

LXM 2024-06-22 12:25:23 +08:00 committed by GitHub
parent 3bbd75f1f2
commit e8ad0339a3
2 changed files with 96 additions and 18 deletions
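
Summary of the change: Tongyi's JSON response mode previously branched on LLMMode and, in completion mode, primed the output by appending an AssistantPromptMessage, which broke JSON output for Qwen models. The patch drops the mode branch: the opening ```JSON fence is always attached to (or appended as) a user message, the system prompt now tells the model to complete the started code block, and Generation.call is always given a messages list. The sketch below restates the priming step outside the Dify runtime; the dataclass is a simplified stand-in for Dify's prompt-message entity, not the real class.

from dataclasses import dataclass

@dataclass
class UserPromptMessage:  # simplified stand-in, not Dify's real entity
    content: str

def prime_for_code_block(prompt_messages: list, code_block: str = "JSON") -> list:
    # Mirror the patched branch: always end the conversation with a user
    # message whose tail opens a fenced block for the model to complete.
    if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
        prompt_messages[-1].content += f"\n```{code_block}\n"
    else:
        prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n"))
    return prompt_messages

messages = prime_for_code_block([UserPromptMessage(content="Return the user as JSON.")])
print(messages[-1].content)  # ends with an opening ```JSON fence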


@@ -18,7 +18,7 @@ from dashscope.common.error import (
 )
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -82,6 +82,7 @@ if you are not sure about the structure.
 <instructions>
 {{instructions}}
 </instructions>
+You should also complete the text started with ``` but not tell ``` directly.
 """
 
         code_block = model_parameters.get("response_format", "")
@@ -113,21 +114,17 @@ if you are not sure about the structure.
         # insert the system message
         prompt_messages.insert(0, SystemPromptMessage(
             content=block_prompts
-            .replace("{{instructions}}", f"Please output a valid {code_block} object.")
+            .replace("{{instructions}}", f"Please output a valid {code_block} with markdown codeblocks.")
         ))
 
-        mode = self.get_model_mode(model, credentials)
-        if mode == LLMMode.CHAT:
-            if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
-                # add ```JSON\n to the last message
-                prompt_messages[-1].content += f"\n```{code_block}\n"
-            else:
-                # append a user message
-                prompt_messages.append(UserPromptMessage(
-                    content=f"```{code_block}\n"
-                ))
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+            # add ```JSON\n to the last message
+            prompt_messages[-1].content += f"\n```{code_block}\n"
         else:
-            prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))
+            # append a user message
+            prompt_messages.append(UserPromptMessage(
+                content=f"```{code_block}\n"
+            ))
 
         response = self._invoke(
             model=model,
@@ -243,11 +240,8 @@ if you are not sure about the structure.
             response = MultiModalConversation.call(**params, stream=stream)
         else:
-            if mode == LLMMode.CHAT:
-                params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages)
-            else:
-                params['prompt'] = prompt_messages[0].content.rstrip()
+            # nothing different between chat model and completion model in tongyi
+            params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages)
             response = Generation.call(**params,
                                        result_format='message',
                                        stream=stream)
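
Because the model is now told to complete a started ``` block, its raw completion typically arrives wrapped in (or terminated by) markdown fences. Dify handles this downstream; purely as an illustration, a hypothetical helper like the following could recover the JSON payload. extract_json_block is not part of this commit.

import json
import re

def extract_json_block(text: str):
    # Parse model output that may arrive as ```JSON ... ``` or as bare JSON.
    match = re.search(r"```(?:json)?\s*(.*?)\s*(?:```|$)", text, re.DOTALL | re.IGNORECASE)
    candidate = match.group(1) if match else text
    return json.loads(candidate)

print(extract_json_block('```JSON\n{"code": 200}\n```'))  # {'code': 200}
print(extract_json_block('{"code": 200}'))                # {'code': 200}

The second diff in this commit adds a new integration test that exercises the JSON response format end to end.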


@@ -0,0 +1,84 @@
+import json
+import os
+from collections.abc import Generator
+
+from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
+from core.model_runtime.model_providers.tongyi.llm.llm import TongyiLargeLanguageModel
+
+
+def test_invoke_model_with_json_response():
+    """
+    Test the invocation of a model with JSON response.
+    """
+    model_list = [
+        "qwen-max-0403",
+        "qwen-max-1201",
+        "qwen-max-longcontext",
+        "qwen-max",
+        "qwen-plus-chat",
+        "qwen-plus",
+        "qwen-turbo-chat",
+        "qwen-turbo",
+    ]
+    for model_name in model_list:
+        print("testing model: ", model_name)
+        invoke_model_with_json_response(model_name)
+
+
+def invoke_model_with_json_response(model_name="qwen-max-0403"):
+    """
+    Method to invoke the model with JSON response format.
+
+    Args:
+        model_name (str): The name of the model to invoke. Defaults to "qwen-max-0403".
+
+    Returns:
+        None
+    """
+    model = TongyiLargeLanguageModel()
+    response = model.invoke(
+        model=model_name,
+        credentials={
+            'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='output json data with format `{"data": "test", "code": 200, "msg": "success"}`'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.5,
+            'max_tokens': 50,
+            'response_format': 'JSON',
+        },
+        stream=True,
+        user="abc-123"
+    )
+    print("=====================================")
+    print(response)
+    assert isinstance(response, Generator)
+    output = ""
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        output += chunk.delta.message.content
+    assert is_json(output)
+
+
+def is_json(s):
+    """
+    Check if a string is a valid JSON.
+
+    Args:
+        s (str): The string to check.
+
+    Returns:
+        bool: True if the string is a valid JSON, False otherwise.
+    """
+    try:
+        json.loads(s)
+    except ValueError:
+        return False
+    return True
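
Note that this test calls the live DashScope API for every model in model_list, so it requires TONGYI_DASHSCOPE_API_KEY to be set and will fail hard without it. A hypothetical guard (not in this commit) could skip the test gracefully on machines without credentials:

import os

import pytest

# Skip the live-API test when no DashScope key is configured,
# so CI environments without credentials stay green.
requires_dashscope = pytest.mark.skipif(
    not os.environ.get("TONGYI_DASHSCOPE_API_KEY"),
    reason="TONGYI_DASHSCOPE_API_KEY is not set",
)

@requires_dashscope
def test_invoke_model_with_json_response():
    ...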