fix(api/core/model_runtime/model_providers/__base/large_language_model.py): Add TEXT type checker (#7407)

This commit is contained in:
-LAN- 2024-08-19 18:45:30 +08:00 committed by GitHub
parent bd07e1d2fd
commit 0087afc2e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -185,7 +185,7 @@ if you are not sure about the structure.
stream=stream, stream=stream,
user=user user=user
) )
model_parameters.pop("response_format") model_parameters.pop("response_format")
stop = stop or [] stop = stop or []
stop.extend(["\n```", "```\n"]) stop.extend(["\n```", "```\n"])
@ -249,10 +249,10 @@ if you are not sure about the structure.
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
input_generator=new_generator() input_generator=new_generator()
) )
return response return response
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None] input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]: ) -> Generator[LLMResultChunk, None, None]:
""" """
@ -310,7 +310,7 @@ if you are not sure about the structure.
) )
) )
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \ input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]: -> Generator[LLMResultChunk, None, None]:
""" """
@ -470,7 +470,7 @@ if you are not sure about the structure.
:return: full response or stream response chunk generator result :return: full response or stream response chunk generator result
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
@ -792,6 +792,13 @@ if you are not sure about the structure.
if not isinstance(parameter_value, str): if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be string.") raise ValueError(f"Model Parameter {parameter_name} should be string.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
elif parameter_rule.type == ParameterType.TEXT:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be text.")
# validate options # validate options
if parameter_rule.options and parameter_value not in parameter_rule.options: if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")