mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-13 23:16:04 +08:00)
Feat/json mode (#2563)
parent: 0620fa3094
commit: 3e63abd335
@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
         'min': 1,
         'max': 2048,
         'precision': 0,
     },
+    DefaultParameterName.RESPONSE_FORMAT: {
+        'label': {
+            'en_US': 'Response Format',
+            'zh_Hans': '回复格式',
+        },
+        'type': 'string',
+        'help': {
+            'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.',
+            'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
+        },
+        'required': False,
+        'options': ['JSON', 'XML'],
+    }
 }

@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
     PRESENCE_PENALTY = "presence_penalty"
     FREQUENCY_PENALTY = "frequency_penalty"
     MAX_TOKENS = "max_tokens"
+    RESPONSE_FORMAT = "response_format"

     @classmethod
     def value_of(cls, value: Any) -> 'DefaultParameterName':

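For orientation, the new member is resolved from its string value whenever a provider rule declares `use_template: response_format`. A stand-alone sketch of that lookup follows; the real `value_of` body is not part of this diff, so the implementation below is an assumption, not the actual Dify class.

    from enum import Enum

    class DefaultParameterName(Enum):
        MAX_TOKENS = "max_tokens"
        RESPONSE_FORMAT = "response_format"   # added by this commit

        @classmethod
        def value_of(cls, value) -> "DefaultParameterName":
            # assumed behaviour: map a template string from provider YAML to the enum member
            for member in cls:
                if member.value == value:
                    return member
            raise ValueError(f"invalid parameter name {value}")

    assert DefaultParameterName.value_of("response_format") is DefaultParameterName.RESPONSE_FORMAT
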
@@ -262,23 +262,23 @@ class AIModel(ABC):
             try:
                 default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
                 default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
-                if not parameter_rule.max:
+                if not parameter_rule.max and 'max' in default_parameter_rule:
                     parameter_rule.max = default_parameter_rule['max']
-                if not parameter_rule.min:
+                if not parameter_rule.min and 'min' in default_parameter_rule:
                     parameter_rule.min = default_parameter_rule['min']
-                if not parameter_rule.precision:
+                if not parameter_rule.default and 'default' in default_parameter_rule:
                     parameter_rule.default = default_parameter_rule['default']
-                if not parameter_rule.precision:
+                if not parameter_rule.precision and 'precision' in default_parameter_rule:
                     parameter_rule.precision = default_parameter_rule['precision']
-                if not parameter_rule.required:
+                if not parameter_rule.required and 'required' in default_parameter_rule:
                     parameter_rule.required = default_parameter_rule['required']
-                if not parameter_rule.help:
+                if not parameter_rule.help and 'help' in default_parameter_rule:
                     parameter_rule.help = I18nObject(
                         en_US=default_parameter_rule['help']['en_US'],
                     )
-                if not parameter_rule.help.en_US:
+                if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
                     parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
-                if not parameter_rule.help.zh_Hans:
+                if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
                     parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
             except ValueError:
                 pass

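The added `'… in default_parameter_rule'` guards mean a provider rule now only inherits a field when the default template actually defines it, which matters for the new `response_format` template that carries no `min`/`max`/`precision`. A minimal dict-based sketch of the same merge behaviour (plain dicts stand in for the runtime's ParameterRule objects, which are not reproduced here):

    # Simplified sketch; not the actual Dify classes.
    def merge_with_defaults(rule: dict, defaults: dict) -> dict:
        merged = dict(rule)
        for key in ('max', 'min', 'default', 'precision', 'required', 'help'):
            # fall back to the template only when the field is unset AND the template provides it
            if not merged.get(key) and key in defaults:
                merged[key] = defaults[key]
        return merged

    response_format_defaults = {'required': False, 'options': ['JSON', 'XML'],
                                'help': {'en_US': 'Set a response format...'}}
    print(merge_with_defaults({'name': 'response_format'}, response_format_defaults))
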
@@ -9,7 +9,13 @@ from typing import Optional, Union
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.callbacks.logging_callback import LoggingCallback
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import (
     ModelPropertyKey,
     ModelType,

@@ -74,6 +80,19 @@ class LargeLanguageModel(AIModel):
         )

         try:
+            if "response_format" in model_parameters:
+                result = self._code_block_mode_wrapper(
+                    model=model,
+                    credentials=credentials,
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=tools,
+                    stop=stop,
+                    stream=stream,
+                    user=user,
+                    callbacks=callbacks
+                )
+            else:
-            result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+                result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
         except Exception as e:
             self._trigger_invoke_error_callbacks(

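In other words, `invoke()` now branches on the presence of `response_format` in the caller's parameters; everything else keeps the original `_invoke` path. A trivial sketch of that routing decision (the names below are illustrative, not the public Dify API):

    def route_invoke(model_parameters: dict) -> str:
        # any request carrying response_format goes through the code-block-mode wrapper
        if "response_format" in model_parameters:
            return "_code_block_mode_wrapper"
        return "_invoke"

    assert route_invoke({"temperature": 0.7}) == "_invoke"
    assert route_invoke({"temperature": 0.7, "response_format": "JSON"}) == "_code_block_mode_wrapper"
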
@@ -120,6 +139,239 @@ class LargeLanguageModel(AIModel):

         return result

+    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper, ensure the response is a code block with output markdown quote
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :param callbacks: callbacks
+        :return: full response or stream response chunk generator result
+        """
+
+        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
+        code_block = model_parameters.get("response_format", "")
+        if not code_block:
+            return self._invoke(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user
+            )
+
+        model_parameters.pop("response_format")
+        stop = stop or []
+        stop.extend(["\n```", "```\n"])
+        block_prompts = block_prompts.replace("{{block}}", code_block)
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=block_prompts
+                .replace("{{instructions}}", prompt_messages[0].content)
+            )
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=block_prompts
+                .replace("{{instructions}}", f"Please output a valid {code_block} object.")
+            ))
+
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+            # add ```JSON\n to the last message
+            prompt_messages[-1].content += f"\n```{code_block}\n"
+        else:
+            # append a user message
+            prompt_messages.append(UserPromptMessage(
+                content=f"```{code_block}\n"
+            ))
+
+        response = self._invoke(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user
+        )
+
+        if isinstance(response, Generator):
+            first_chunk = next(response)
+            def new_generator():
+                yield first_chunk
+                yield from response
+
+            if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
+                return self._code_block_mode_stream_processor_with_backtick(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    input_generator=new_generator()
+                )
+            else:
+                return self._code_block_mode_stream_processor(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    input_generator=new_generator()
+                )
+
+        return response
+
+    def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
+                                          input_generator: Generator[LLMResultChunk, None, None]
+                                          ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Code block mode stream processor, ensure the response is a code block with output markdown quote
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :param input_generator: input generator
+        :return: output generator
+        """
+        state = "normal"
+        backtick_count = 0
+        for piece in input_generator:
+            if piece.delta.message.content:
+                content = piece.delta.message.content
+                piece.delta.message.content = ""
+                yield piece
+                piece = content
+            else:
+                yield piece
+                continue
+            new_piece = ""
+            for char in piece:
+                if state == "normal":
+                    if char == "`":
+                        state = "in_backticks"
+                        backtick_count = 1
+                    else:
+                        new_piece += char
+                elif state == "in_backticks":
+                    if char == "`":
+                        backtick_count += 1
+                        if backtick_count == 3:
+                            state = "skip_content"
+                            backtick_count = 0
+                    else:
+                        new_piece += "`" * backtick_count + char
+                        state = "normal"
+                        backtick_count = 0
+                elif state == "skip_content":
+                    if char.isspace():
+                        state = "normal"
+
+            if new_piece:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(
+                            content=new_piece,
+                            tool_calls=[]
+                        ),
+                    )
+                )
+
+    def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
+                                                        input_generator: Generator[LLMResultChunk, None, None]) \
+            -> Generator[LLMResultChunk, None, None]:
+        """
+        Code block mode stream processor, ensure the response is a code block with output markdown quote.
+        This version skips the language identifier that follows the opening triple backticks.
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :param input_generator: input generator
+        :return: output generator
+        """
+        state = "search_start"
+        backtick_count = 0
+
+        for piece in input_generator:
+            if piece.delta.message.content:
+                content = piece.delta.message.content
+                # Reset content to ensure we're only processing and yielding the relevant parts
+                piece.delta.message.content = ""
+                # Yield a piece with cleared content before processing it to maintain the generator structure
+                yield piece
+                piece = content
+            else:
+                # Yield pieces without content directly
+                yield piece
+                continue
+
+            if state == "done":
+                continue
+
+            new_piece = ""
+            for char in piece:
+                if state == "search_start":
+                    if char == "`":
+                        backtick_count += 1
+                        if backtick_count == 3:
+                            state = "skip_language"
+                            backtick_count = 0
+                    else:
+                        backtick_count = 0
+                elif state == "skip_language":
+                    # Skip everything until the first newline, marking the end of the language identifier
+                    if char == "\n":
+                        state = "in_code_block"
+                elif state == "in_code_block":
+                    if char == "`":
+                        backtick_count += 1
+                        if backtick_count == 3:
+                            state = "done"
+                            break
+                    else:
+                        if backtick_count > 0:
+                            # If backticks were counted but we're still collecting content, it was a false start
+                            new_piece += "`" * backtick_count
+                            backtick_count = 0
+                        new_piece += char
+
+                elif state == "done":
+                    break
+
+            if new_piece:
+                # Only yield content collected within the code block
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(
+                            content=new_piece,
+                            tool_calls=[]
+                        ),
+                    )
+                )
+
     def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
                                  prompt_messages: list[PromptMessage], model_parameters: dict,
                                  tools: Optional[list[PromptMessageTool]] = None,

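The two stream processors above are small character-level state machines: the plain variant drops any triple-backtick fences it meets, while the `_with_backtick` variant assumes the reply opens with a fence, skips the language identifier, and stops at the closing fence. A stand-alone sketch of the latter, operating on plain string chunks instead of LLMResultChunk objects (an assumption made for brevity):

    from collections.abc import Iterator

    def strip_code_fence(chunks: Iterator[str]) -> Iterator[str]:
        # mirrors the "search_start" / "skip_language" / "in_code_block" / "done" states above
        state, backticks = "search_start", 0
        for piece in chunks:
            if state == "done":
                continue
            out = ""
            for char in piece:
                if state == "search_start":
                    if char == "`":
                        backticks += 1
                        if backticks == 3:
                            state, backticks = "skip_language", 0
                    else:
                        backticks = 0
                elif state == "skip_language":
                    if char == "\n":          # end of the ```JSON language tag
                        state = "in_code_block"
                elif state == "in_code_block":
                    if char == "`":
                        backticks += 1
                        if backticks == 3:    # closing fence reached
                            state = "done"
                            break
                    else:
                        if backticks:         # lone backticks inside the block are kept
                            out += "`" * backticks
                            backticks = 0
                        out += char
            if out:
                yield out

    pieces = ['```JS', 'ON\n{"ans', 'wer": "42"}', '\n```']
    print("".join(strip_code_fence(pieces)))   # {"answer": "42"} plus the newline before the closing fence
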
@@ -27,6 +27,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '8.00'
   output: '24.00'

@@ -27,6 +27,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '8.00'
   output: '24.00'

@@ -26,6 +26,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '1.63'
   output: '5.51'

@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
 from anthropic.types import Completion, completion_create_params
 from httpx import Timeout

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,

@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

+ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
 class AnthropicLargeLanguageModel(LargeLanguageModel):

     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,

@@ -49,6 +57,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

+    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper for invoking large language model
+        """
+        if 'response_format' in model_parameters and model_parameters['response_format']:
+            stop = stop or []
+            self._transform_json_prompts(
+                model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
+            )
+            model_parameters.pop('response_format')
+
+        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def _transform_json_prompts(self, model: str, credentials: dict,
+                                prompt_messages: list[PromptMessage], model_parameters: dict,
+                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
+        """
+        Transform json prompts
+        """
+        if "```\n" not in stop:
+            stop.append("```\n")
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=ANTHROPIC_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", prompt_messages[0].content)
+                .replace("{{block}}", response_format)
+            )
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=ANTHROPIC_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+                .replace("{{block}}", response_format)
+            ))
+
+        prompt_messages.append(AssistantPromptMessage(
+            content=f"```{response_format}\n"
+        ))
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """

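The Anthropic variant relies on prompt shaping alone: the block-mode preamble becomes the system message, a trailing assistant turn pre-opens the code fence, and "```\n" is registered as a stop word so generation ends when the fence closes. An illustrative sketch of the resulting message list, with plain dicts standing in for Dify's prompt message classes and a condensed preamble:

    def transform_json_prompts(messages: list[dict], response_format: str = "JSON") -> list[dict]:
        # Illustrative only; dicts replace SystemPromptMessage / AssistantPromptMessage.
        if messages and messages[0]["role"] == "system":
            instructions, rest = messages[0]["content"], messages[1:]
        else:
            instructions, rest = f"Please output a valid {response_format} object.", messages
        system = {
            "role": "system",
            "content": (f"You should always follow the instructions and output a valid {response_format} object.\n"
                        f"<instructions>\n{instructions}\n</instructions>"),
        }
        # the appended assistant turn pre-opens the fence; the "```\n" stop word closes the generation
        return [system, *rest, {"role": "assistant", "content": f"```{response_format}\n"}]

    print(transform_json_prompts([{"role": "user", "content": "Give me the answer as JSON."}]))
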
@@ -27,6 +27,8 @@ parameter_rules:
     default: 2048
     min: 1
     max: 2048
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.00'
   output: '0.00'

@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

 logger = logging.getLogger(__name__)

+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
+
 class GoogleLargeLanguageModel(LargeLanguageModel):

     def _invoke(self, model: str, credentials: dict,

@@ -24,6 +24,18 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0.0005'
   output: '0.0015'

@@ -24,6 +24,8 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.0015'
   output: '0.002'

@@ -24,6 +24,18 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0.001'
   output: '0.002'

@@ -24,6 +24,8 @@ parameter_rules:
     default: 512
     min: 1
     max: 16385
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.003'
   output: '0.004'

@@ -24,6 +24,8 @@ parameter_rules:
     default: 512
     min: 1
     max: 16385
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.003'
   output: '0.004'

@@ -21,6 +21,8 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.0015'
   output: '0.002'

@@ -24,6 +24,18 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0.001'
   output: '0.002'

@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
 from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
 from openai.types.chat.chat_completion_message import FunctionCall

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,

@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI

 logger = logging.getLogger(__name__)

+OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
 class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
     """

@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 user=user
             )

+    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper for invoking large language model
+        """
+        # handle fine tune remote models
+        base_model = model
+        if model.startswith('ft:'):
+            base_model = model.split(':')[1]
+
+        # get model mode
+        model_mode = self.get_model_mode(base_model, credentials)
+
+        # transform response format
+        if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
+            stop = stop or []
+            if model_mode == LLMMode.CHAT:
+                # chat model
+                self._transform_chat_json_prompts(
+                    model=base_model,
+                    credentials=credentials,
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=tools,
+                    stop=stop,
+                    stream=stream,
+                    user=user,
+                    response_format=model_parameters['response_format']
+                )
+            else:
+                self._transform_completion_json_prompts(
+                    model=base_model,
+                    credentials=credentials,
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=tools,
+                    stop=stop,
+                    stream=stream,
+                    user=user,
+                    response_format=model_parameters['response_format']
+                )
+            model_parameters.pop('response_format')
+
+        return self._invoke(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user
+        )
+
+    def _transform_chat_json_prompts(self, model: str, credentials: dict,
+                                     prompt_messages: list[PromptMessage], model_parameters: dict,
+                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                     stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
+        """
+        Transform json prompts
+        """
+        if "```\n" not in stop:
+            stop.append("```\n")
+        if "\n```" not in stop:
+            stop.append("\n```")
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=OPENAI_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", prompt_messages[0].content)
+                .replace("{{block}}", response_format)
+            )
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=OPENAI_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+                .replace("{{block}}", response_format)
+            ))
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
+
+    def _transform_completion_json_prompts(self, model: str, credentials: dict,
+                                           prompt_messages: list[PromptMessage], model_parameters: dict,
+                                           tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                           stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
+        """
+        Transform json prompts
+        """
+        if "```\n" not in stop:
+            stop.append("```\n")
+        if "\n```" not in stop:
+            stop.append("\n```")
+
+        # override the last user message
+        user_message = None
+        for i in range(len(prompt_messages) - 1, -1, -1):
+            if isinstance(prompt_messages[i], UserPromptMessage):
+                user_message = prompt_messages[i]
+                break
+
+        if user_message:
+            if prompt_messages[i].content[-11:] == 'Assistant: ':
+                # now we are in the chat app, remove the last assistant message
+                prompt_messages[i].content = prompt_messages[i].content[:-11]
+                prompt_messages[i] = UserPromptMessage(
+                    content=OPENAI_BLOCK_MODE_PROMPT
+                    .replace("{{instructions}}", user_message.content)
+                    .replace("{{block}}", response_format)
+                )
+                prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
+            else:
+                prompt_messages[i] = UserPromptMessage(
+                    content=OPENAI_BLOCK_MODE_PROMPT
+                    .replace("{{instructions}}", user_message.content)
+                    .replace("{{block}}", response_format)
+                )
+                prompt_messages[i].content += f"\n```{response_format}\n"
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """

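One OpenAI-specific wrinkle above is fine-tuned model names: a model id of the form `ft:<base>:<org>:...` has to be mapped back to its base model before looking up the chat/completion mode. A small sketch of that resolution (the fine-tune id below is made up):

    def resolve_base_model(model: str) -> str:
        # "ft:gpt-3.5-turbo-1106:acme::abc123" -> "gpt-3.5-turbo-1106"
        return model.split(':')[1] if model.startswith('ft:') else model

    assert resolve_base_model('gpt-3.5-turbo') == 'gpt-3.5-turbo'
    assert resolve_base_model('ft:gpt-3.5-turbo-1106:acme::abc123') == 'gpt-3.5-turbo-1106'
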
@@ -13,6 +13,7 @@ from dashscope.common.error import (
 )
 from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,

@@ -58,6 +59,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

+    def _code_block_mode_wrapper(self, model: str, credentials: dict,
+                                 prompt_messages: list[PromptMessage], model_parameters: dict,
+                                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                 stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
+            -> LLMResult | Generator:
+        """
+        Wrapper for code block mode
+        """
+        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
+        code_block = model_parameters.get("response_format", "")
+        if not code_block:
+            return self._invoke(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user
+            )
+
+        model_parameters.pop("response_format")
+        stop = stop or []
+        stop.extend(["\n```", "```\n"])
+        block_prompts = block_prompts.replace("{{block}}", code_block)
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=block_prompts
+                .replace("{{instructions}}", prompt_messages[0].content)
+            )
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=block_prompts
+                .replace("{{instructions}}", f"Please output a valid {code_block} object.")
+            ))
+
+        mode = self.get_model_mode(model, credentials)
+        if mode == LLMMode.CHAT:
+            if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+                # add ```JSON\n to the last message
+                prompt_messages[-1].content += f"\n```{code_block}\n"
+            else:
+                # append a user message
+                prompt_messages.append(UserPromptMessage(
+                    content=f"```{code_block}\n"
+                ))
+        else:
+            prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))
+
+        response = self._invoke(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user
+        )
+
+        if isinstance(response, Generator):
+            return self._code_block_mode_stream_processor_with_backtick(
+                model=model,
+                prompt_messages=prompt_messages,
+                input_generator=response
+            )
+
+        return response
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """

@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         """
         extra_model_kwargs = {}
         if stop:
-            extra_model_kwargs['stop_sequences'] = stop
+            extra_model_kwargs['stop'] = stop

         # transform credentials to kwargs for model instance
         credentials_kwargs = self._to_credential_kwargs(credentials)

@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         params = {
             'model': model,
             **model_parameters,
-            **credentials_kwargs
+            **credentials_kwargs,
+            **extra_model_kwargs,
         }

         mode = self.get_model_mode(model, credentials)

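The two Tongyi hunks above are a companion fix: stop words are now collected under the key `stop` rather than `stop_sequences`, and merged into the request via `**extra_model_kwargs`, so the fences added by code block mode actually terminate generation. A sketch of the resulting params assembly with placeholder values:

    extra_model_kwargs = {}
    stop = ["\n```", "```\n"]             # added by the code block mode wrapper
    if stop:
        extra_model_kwargs['stop'] = stop

    params = {
        'model': 'qwen-plus',             # illustrative model name
        **{'temperature': 0.3},           # stands in for model_parameters
        **{'api_key': 'sk-placeholder'},  # stands in for credentials_kwargs
        **extra_model_kwargs,
    }
    assert params['stop'] == ["\n```", "```\n"]
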
@@ -57,3 +57,5 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
     required: false
+  - name: response_format
+    use_template: response_format

@@ -57,3 +57,5 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
     required: false
+  - name: response_format
+    use_template: response_format

@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
    required: false
+  - name: response_format
+    use_template: response_format

@@ -56,6 +56,8 @@ parameter_rules:
     help:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.02'
   output: '0.02'

@@ -57,6 +57,8 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
     required: false
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.008'
   output: '0.008'

@@ -25,6 +25,8 @@ parameter_rules:
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    use_template: response_format
   - name: disable_search
     label:
       zh_Hans: 禁用搜索

@@ -25,6 +25,8 @@ parameter_rules:
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    use_template: response_format
   - name: disable_search
     label:
       zh_Hans: 禁用搜索

@@ -25,3 +25,5 @@ parameter_rules:
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    use_template: response_format

@@ -34,3 +34,5 @@ parameter_rules:
       zh_Hans: 禁用模型自行进行外部搜索。
       en_US: Disable the model to perform external search.
     required: false
+  - name: response_format
+    use_template: response_format

@@ -1,6 +1,7 @@
 from collections.abc import Generator
-from typing import cast
+from typing import Optional, Union, cast

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,

@@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
     RateLimitReachedError,
 )

+ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+
+You should also complete the text started with ``` but not tell ``` directly.
+"""
+
-class ErnieBotLarguageModel(LargeLanguageModel):
+class ErnieBotLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,

@@ -39,6 +50,62 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
                               model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)

+    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper for invoking large language model
+        """
+        if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
+            response_format = model_parameters['response_format']
+            stop = stop or []
+            self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format)
+            model_parameters.pop('response_format')
+            if stream:
+                return self._code_block_mode_stream_processor(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages,
+                                                 model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
+                )
+
+        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def _transform_json_prompts(self, model: str, credentials: dict,
+                                prompt_messages: list[PromptMessage], model_parameters: dict,
+                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
+        """
+        Transform json prompts to model prompts
+        """
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=ERNIE_BOT_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", prompt_messages[0].content)
+                .replace("{{block}}", response_format)
+            )
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=ERNIE_BOT_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+                .replace("{{block}}", response_format)
+            ))
+
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+            # add ```JSON\n to the last message
+            prompt_messages[-1].content += "\n```JSON\n{\n"
+        else:
+            # append a user message
+            prompt_messages.append(UserPromptMessage(
+                content="```JSON\n{\n"
+            ))
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: list[PromptMessageTool] | None = None) -> int:
         # tools is not supported yet

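The ERNIE transform goes one step further than the generic wrapper: besides the block-mode system prompt, it primes the last user message with an opened fence and an opening brace, so the model continues from `{` and the shared stream processor strips the fence afterwards. A dict-based sketch of that priming step (illustrative, not the Dify message classes):

    def prime_last_user_message(messages: list[dict]) -> list[dict]:
        # mirrors the "add ```JSON\n{\n to the last message" branch above
        if messages and messages[-1]["role"] == "user":
            messages[-1]["content"] += "\n```JSON\n{\n"
        else:
            messages.append({"role": "user", "content": "```JSON\n{\n"})
        return messages

    print(prime_last_user_message([{"role": "user", "content": "Return the answer as JSON."}]))
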
@@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp
 from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
 from core.model_runtime.utils import helper

+GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
+The structure of the JSON object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+And you should always end the block with a "```" to indicate the end of the JSON object.
+
+<instructions>
+{{instructions}}
+</instructions>
+
+```JSON"""
+
 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

@@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         credentials_kwargs = self._to_credential_kwargs(credentials)

         # invoke model
+        # stop = stop or []
+        # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
         return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    # def _transform_json_prompts(self, model: str, credentials: dict,
+    #                             prompt_messages: list[PromptMessage], model_parameters: dict,
+    #                             tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+    #                             stream: bool = True, user: str | None = None) \
+    #         -> None:
+    #     """
+    #     Transform json prompts to model prompts
+    #     """
+    #     if "}\n\n" not in stop:
+    #         stop.append("}\n\n")
+
+    #     # check if there is a system message
+    #     if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+    #         # override the system message
+    #         prompt_messages[0] = SystemPromptMessage(
+    #             content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
+    #         )
+    #     else:
+    #         # insert the system message
+    #         prompt_messages.insert(0, SystemPromptMessage(
+    #             content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
+    #         ))
+    #     # check if the last message is a user message
+    #     if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+    #         # add ```JSON\n to the last message
+    #         prompt_messages[-1].content += "\n```JSON\n"
+    #     else:
+    #         # append a user message
+    #         prompt_messages.append(UserPromptMessage(
+    #             content="```JSON\n"
+    #         ))
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """

@@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         """
         extra_model_kwargs = {}
         if stop:
-            extra_model_kwargs['stop_sequences'] = stop
+            extra_model_kwargs['stop'] = stop

         client = ZhipuAI(
             api_key=credentials_kwargs['api_key']

@@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         ]

         if stream:
-            response = client.chat.completions.create(stream=stream, **params)
+            response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs)
             return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)

-        response = client.chat.completions.create(**params)
+        response = client.chat.completions.create(**params, **extra_model_kwargs)
         return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)

     def _handle_generate_response(self, model: str,

@@ -7,18 +7,18 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel
+from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel


 def test_predefined_models():
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()
     model_schemas = model.predefined_models()
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)


 def test_validate_credentials_for_chat_model():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(

@@ -39,7 +39,7 @@ def test_validate_credentials_for_chat_model():

 def test_invoke_model_ernie_bot():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot',

@@ -67,7 +67,7 @@ def test_invoke_model_ernie_bot():

 def test_invoke_model_ernie_bot_turbo():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot-turbo',

@@ -95,7 +95,7 @@ def test_invoke_model_ernie_bot_turbo():

 def test_invoke_model_ernie_8k():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot-8k',

@@ -123,7 +123,7 @@ def test_invoke_model_ernie_8k():

 def test_invoke_model_ernie_bot_4():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot-4',

@@ -151,7 +151,7 @@ def test_invoke_model_ernie_bot_4():

 def test_invoke_stream_model():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot',

@@ -182,7 +182,7 @@ def test_invoke_stream_model():

 def test_invoke_model_with_system():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot',

@@ -212,7 +212,7 @@ def test_invoke_model_with_system():

 def test_invoke_with_search():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot',

@@ -250,7 +250,7 @@ def test_invoke_with_search():

 def test_get_num_tokens():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.get_num_tokens(
         model='ernie-bot',
|
Loading…
x
Reference in New Issue
Block a user