feat: support json schema for gemini models (#10835)

非法操作 2024-11-19 17:49:58 +08:00 committed by GitHub
parent 9f195df103
commit bc1013dacf
18 changed files with 61 additions and 77 deletions

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

View File

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
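
The hunks above apply one identical rule change to each Gemini model config: the max-token parameter is renamed from max_tokens_to_sample to max_output_tokens, it is no longer required (falling back to default: 8192 within [1, 8192]), and the response_format rule is replaced by a json_schema rule. A minimal sketch of how a runtime might resolve the renamed parameter under the new rule; the helper and its name are illustrative assumptions, not Dify code:

# Illustrative only: parameter name, default, and bounds come from the YAML
# hunks above; this helper itself is an assumption, not part of the commit.
DEFAULT_MAX_OUTPUT_TOKENS = 8192  # `default: 8192` in the rule


def resolve_max_output_tokens(model_parameters: dict) -> int:
    # `required: true` was dropped, so fall back to the YAML default,
    # then clamp into the declared [min, max] = [1, 8192] range.
    value = model_parameters.get("max_output_tokens", DEFAULT_MAX_OUTPUT_TOKENS)
    return max(1, min(8192, int(value)))


assert resolve_max_output_tokens({"max_output_tokens": 2048}) == 2048
assert resolve_max_output_tokens({}) == 8192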

View File

@@ -32,3 +32,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true

View File

@@ -36,3 +36,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true

View File

@@ -1,7 +1,6 @@
 import base64
 import io
 import json
-import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast
@@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
-logger = logging.getLogger(__name__)
-
-GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
-The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
-if you are not sure about the structure.
-
-<instructions>
-{{instructions}}
-</instructions>
-"""  # noqa: E501
-
 
 class GoogleLargeLanguageModel(LargeLanguageModel):
     def _invoke(
@@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         try:
             ping_message = SystemPromptMessage(content="ping")
-            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
+        if schema := config_kwargs.pop("json_schema", None):
+            try:
+                schema = json.loads(schema)
+            except (TypeError, ValueError):
+                raise exceptions.InvalidArgument("Invalid JSON Schema")
+            if tools:
+                raise exceptions.InvalidArgument("Gemini does not support Tools and JSON Schema at the same time")
+            config_kwargs["response_schema"] = schema
+            config_kwargs["response_mime_type"] = "application/json"
 
         if stop:
             config_kwargs["stop_sequences"] = stop

View File

@@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
         stop=["How"],
         stream=False,
         user="abc-123",
@@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
         stream=True,
         user="abc-123",
     )
@@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(content="You are a helpful AI assistant."),
@@ -164,7 +164,7 @@ def test_get_num_tokens():
     model = GoogleLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
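
The updated tests pin the renamed parameters and the gemini-1.5-pro model id. A hypothetical companion call (not part of this commit) exercising the new json_schema path through the same invoke surface; the import paths and message/result classes are assumptions based on the test module above:

import json
import os

from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel

model = GoogleLargeLanguageModel()
response = model.invoke(
    model="gemini-1.5-pro",
    credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
    prompt_messages=[UserPromptMessage(content="Name three primary colors.")],
    model_parameters={
        "temperature": 0.5,
        "max_output_tokens": 2048,
        # Passed as a string; llm.py json.loads() it into response_schema.
        "json_schema": json.dumps(
            {
                "type": "object",
                "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
                "required": ["colors"],
            }
        ),
    },
    stream=False,
    user="abc-123",
)
print(response.message.content)  # expected: a JSON object matching the schema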