Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-14 00:35:56 +08:00)
feat: support json schema for gemini models (#10835)
This commit is contained in:
parent 9f195df103
commit bc1013dacf
The following hunk is repeated verbatim for each of the Gemini model YAML files touched by the commit (14 files); it is shown once here:

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
   - name: response_format
     use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
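As a usage sketch (not part of the commit), here is how the renamed and newly added parameters might be passed at invoke time. The import paths are assumed from dify's code layout, and the schema content is illustrative; note that json_schema travels as a JSON string, which the provider parses with json.loads (see the Python hunks below).

import json
import os

from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel

model = GoogleLargeLanguageModel()
response = model.invoke(
    model="gemini-1.5-pro",
    credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
    prompt_messages=[UserPromptMessage(content="Name three primary colors.")],
    model_parameters={
        "max_output_tokens": 1024,  # renamed from max_tokens_to_sample
        "json_schema": json.dumps(  # new: forwarded to Gemini as a response schema
            {"type": "object", "properties": {"colors": {"type": "array", "items": {"type": "string"}}}}
        ),
    },
    stream=False,
    user="abc-123",
)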
Two of the older model YAMLs are additionally marked as deprecated:

@@ -32,3 +32,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
@@ -36,3 +36,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
The Python changes to the Google LLM provider:

@@ -1,7 +1,6 @@
 import base64
 import io
 import json
 import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast
@@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
 logger = logging.getLogger(__name__)
 
-GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
-The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
-if you are not sure about the structure.
-
-<instructions>
-{{instructions}}
-</instructions>
-"""  # noqa: E501
-
-
 class GoogleLargeLanguageModel(LargeLanguageModel):
     def _invoke(
@@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
         try:
             ping_message = SystemPromptMessage(content="ping")
-            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
 
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
+        if schema := config_kwargs.pop("json_schema", None):
+            try:
+                schema = json.loads(schema)
+            except:
+                raise exceptions.InvalidArgument("Invalid JSON Schema")
+            if tools:
+                raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
+            config_kwargs["response_schema"] = schema
+            config_kwargs["response_mime_type"] = "application/json"
 
         if stop:
             config_kwargs["stop_sequences"] = stop
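For clarity, a standalone, runnable sketch of the schema handling introduced above. build_config_kwargs is a hypothetical helper name used only for illustration; the exceptions module is assumed to be google.api_core.exceptions, and the bare except in the hunk is tightened here to the exceptions json.loads can actually raise.

import json

from google.api_core import exceptions


def build_config_kwargs(model_parameters: dict, tools: list | None = None) -> dict:
    # Mirrors the branch above: pop the schema string, parse it, and route it
    # to Gemini's structured-output settings.
    config_kwargs = model_parameters.copy()
    if schema := config_kwargs.pop("json_schema", None):
        try:
            schema = json.loads(schema)
        except (TypeError, ValueError):
            raise exceptions.InvalidArgument("Invalid JSON Schema")
        if tools:
            raise exceptions.InvalidArgument("Gemini does not support Tools and JSON Schema at the same time")
        config_kwargs["response_schema"] = schema  # the parsed dict, not the raw string
        config_kwargs["response_mime_type"] = "application/json"
    return config_kwargs


params = {"max_output_tokens": 1024, "json_schema": '{"type": "object"}'}
print(build_config_kwargs(params))
# {'max_output_tokens': 1024, 'response_schema': {'type': 'object'}, 'response_mime_type': 'application/json'}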
The test suite is updated for the renamed parameter and the retired model IDs:

@@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
         stop=["How"],
         stream=False,
         user="abc-123",
@@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
         stream=True,
         user="abc-123",
     )
@@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(content="You are a helpful AI assistant."),
@@ -164,7 +164,7 @@ def test_get_num_tokens():
     model = GoogleLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(