feat: support json schema for gemini models (#10835)
commit bc1013dacf
parent 9f195df103
@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
-    required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
[The same hunk is repeated verbatim for each of the remaining Gemini model YAML files changed by this commit.]
@@ -32,3 +32,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
@@ -36,3 +36,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
@@ -1,7 +1,6 @@
 import base64
 import io
 import json
-import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast
 
@@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
-logger = logging.getLogger(__name__)
-
-GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
-The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
-if you are not sure about the structure.
-
-<instructions>
-{{instructions}}
-</instructions>
-"""  # noqa: E501
-
 
 class GoogleLargeLanguageModel(LargeLanguageModel):
     def _invoke(
@@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
         try:
             ping_message = SystemPromptMessage(content="ping")
-            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
 
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
+        if schema := config_kwargs.pop("json_schema", None):
+            try:
+                schema = json.loads(schema)
+            except:
+                raise exceptions.InvalidArgument("Invalid JSON Schema")
+            if tools:
+                raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
+            config_kwargs["response_schema"] = schema
+            config_kwargs["response_mime_type"] = "application/json"
 
         if stop:
             config_kwargs["stop_sequences"] = stop
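For context, the two keys this new branch writes into config_kwargs correspond to fields of GenerationConfig in the google-generativeai SDK that the provider wraps. A minimal sketch of the equivalent direct SDK call, assuming google-generativeai >= 0.7; the model name and schema here are illustrative, not part of the commit:

import google.generativeai as genai

genai.configure(api_key="...")  # assumption: a valid Google API key

model = genai.GenerativeModel("gemini-1.5-pro")
response = model.generate_content(
    "Answer in JSON.",
    generation_config=genai.GenerationConfig(
        # These are the SDK fields that the new config_kwargs entries
        # ("response_schema", "response_mime_type") feed into.
        response_mime_type="application/json",
        response_schema={"type": "object", "properties": {"answer": {"type": "string"}}},
    ),
)
print(response.text)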
@@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
        credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
         stop=["How"],
         stream=False,
         user="abc-123",
@@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
         stream=True,
         user="abc-123",
     )
@@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(content="You are a helpful AI assistant."),
@@ -164,7 +164,7 @@ def test_get_num_tokens():
     model = GoogleLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
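None of the tests above exercises the new json_schema parameter. A hedged sketch of what such a test could look like, in the style of the surrounding tests; this test is not part of the commit, and UserPromptMessage / LLMResult are assumed to come from the same entities modules the existing tests import from:

def test_invoke_model_with_json_schema(setup_google_mock):
    model = GoogleLargeLanguageModel()

    response = model.invoke(
        model="gemini-1.5-pro",
        credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
        prompt_messages=[
            UserPromptMessage(content="Give me your answer in JSON."),
        ],
        model_parameters={
            "temperature": 0.5,
            "max_output_tokens": 2048,
            # json_schema is passed as a string; _generate json.loads()-es it
            # before forwarding it to Gemini as response_schema.
            "json_schema": '{"type": "object", "properties": {"answer": {"type": "string"}}}',
        },
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)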