mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 23:06:15 +08:00
Merge branch 'main' into feat/plugins
This commit is contained in:
commit
9a242bcac9
7
.github/workflows/api-tests.yml
vendored
7
.github/workflows/api-tests.yml
vendored
@ -27,18 +27,17 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Poetry
|
|
||||||
uses: abatilo/actions-poetry@v3
|
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
cache: 'poetry'
|
|
||||||
cache-dependency-path: |
|
cache-dependency-path: |
|
||||||
api/pyproject.toml
|
api/pyproject.toml
|
||||||
api/poetry.lock
|
api/poetry.lock
|
||||||
|
|
||||||
|
- name: Install Poetry
|
||||||
|
uses: abatilo/actions-poetry@v3
|
||||||
|
|
||||||
- name: Check Poetry lockfile
|
- name: Check Poetry lockfile
|
||||||
run: |
|
run: |
|
||||||
poetry check -C api --lock
|
poetry check -C api --lock
|
||||||
|
7
.github/workflows/db-migration-test.yml
vendored
7
.github/workflows/db-migration-test.yml
vendored
@ -23,18 +23,17 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Poetry
|
|
||||||
uses: abatilo/actions-poetry@v3
|
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
cache: 'poetry'
|
|
||||||
cache-dependency-path: |
|
cache-dependency-path: |
|
||||||
api/pyproject.toml
|
api/pyproject.toml
|
||||||
api/poetry.lock
|
api/poetry.lock
|
||||||
|
|
||||||
|
- name: Install Poetry
|
||||||
|
uses: abatilo/actions-poetry@v3
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: poetry install -C api
|
run: poetry install -C api
|
||||||
|
|
||||||
|
7
.github/workflows/style.yml
vendored
7
.github/workflows/style.yml
vendored
@ -24,15 +24,16 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
files: api/**
|
files: api/**
|
||||||
|
|
||||||
- name: Install Poetry
|
|
||||||
uses: abatilo/actions-poetry@v3
|
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Install Poetry
|
||||||
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
|
uses: abatilo/actions-poetry@v3
|
||||||
|
|
||||||
- name: Python dependencies
|
- name: Python dependencies
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: poetry install -C api --only lint
|
run: poetry install -C api --only lint
|
||||||
|
@ -85,3 +85,4 @@
|
|||||||
cd ../
|
cd ../
|
||||||
poetry run -C api bash dev/pytest/pytest_all_tests.sh
|
poetry run -C api bash dev/pytest/pytest_all_tests.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
205
api/app.py
205
api/app.py
@ -10,44 +10,19 @@ if os.environ.get("DEBUG", "false").lower() != "true":
|
|||||||
grpc.experimental.gevent.init_gevent()
|
grpc.experimental.gevent.init_gevent()
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from logging.handlers import RotatingFileHandler
|
|
||||||
|
|
||||||
from flask import Flask, Response, request
|
from flask import Response
|
||||||
from flask_cors import CORS
|
|
||||||
from werkzeug.exceptions import Unauthorized
|
|
||||||
|
|
||||||
import contexts
|
from app_factory import create_app
|
||||||
from commands import register_commands
|
|
||||||
from configs import dify_config
|
|
||||||
|
|
||||||
# DO NOT REMOVE BELOW
|
# DO NOT REMOVE BELOW
|
||||||
from events import event_handlers # noqa: F401
|
from events import event_handlers # noqa: F401
|
||||||
from extensions import (
|
|
||||||
ext_celery,
|
|
||||||
ext_code_based_extension,
|
|
||||||
ext_compress,
|
|
||||||
ext_database,
|
|
||||||
ext_hosting_provider,
|
|
||||||
ext_login,
|
|
||||||
ext_mail,
|
|
||||||
ext_migrate,
|
|
||||||
ext_proxy_fix,
|
|
||||||
ext_redis,
|
|
||||||
ext_sentry,
|
|
||||||
ext_storage,
|
|
||||||
)
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from extensions.ext_login import login_manager
|
|
||||||
from libs.passport import PassportService
|
|
||||||
|
|
||||||
# TODO: Find a way to avoid importing models here
|
# TODO: Find a way to avoid importing models here
|
||||||
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401
|
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401
|
||||||
from services.account_service import AccountService
|
|
||||||
|
|
||||||
# DO NOT REMOVE ABOVE
|
# DO NOT REMOVE ABOVE
|
||||||
|
|
||||||
@ -60,188 +35,12 @@ if hasattr(time, "tzset"):
|
|||||||
time.tzset()
|
time.tzset()
|
||||||
|
|
||||||
|
|
||||||
class DifyApp(Flask):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# -------------
|
# -------------
|
||||||
# Configuration
|
# Configuration
|
||||||
# -------------
|
# -------------
|
||||||
|
|
||||||
|
|
||||||
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
|
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
|
||||||
# Application Factory Function
|
|
||||||
# ----------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def create_flask_app_with_configs() -> Flask:
|
|
||||||
"""
|
|
||||||
create a raw flask app
|
|
||||||
with configs loaded from .env file
|
|
||||||
"""
|
|
||||||
dify_app = DifyApp(__name__)
|
|
||||||
dify_app.config.from_mapping(dify_config.model_dump())
|
|
||||||
|
|
||||||
# populate configs into system environment variables
|
|
||||||
for key, value in dify_app.config.items():
|
|
||||||
if isinstance(value, str):
|
|
||||||
os.environ[key] = value
|
|
||||||
elif isinstance(value, int | float | bool):
|
|
||||||
os.environ[key] = str(value)
|
|
||||||
elif value is None:
|
|
||||||
os.environ[key] = ""
|
|
||||||
|
|
||||||
return dify_app
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> Flask:
|
|
||||||
app = create_flask_app_with_configs()
|
|
||||||
|
|
||||||
app.secret_key = app.config["SECRET_KEY"]
|
|
||||||
|
|
||||||
log_handlers = None
|
|
||||||
log_file = app.config.get("LOG_FILE")
|
|
||||||
if log_file:
|
|
||||||
log_dir = os.path.dirname(log_file)
|
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
|
||||||
log_handlers = [
|
|
||||||
RotatingFileHandler(
|
|
||||||
filename=log_file,
|
|
||||||
maxBytes=1024 * 1024 * 1024,
|
|
||||||
backupCount=5,
|
|
||||||
),
|
|
||||||
logging.StreamHandler(sys.stdout),
|
|
||||||
]
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
level=app.config.get("LOG_LEVEL"),
|
|
||||||
format=app.config.get("LOG_FORMAT"),
|
|
||||||
datefmt=app.config.get("LOG_DATEFORMAT"),
|
|
||||||
handlers=log_handlers,
|
|
||||||
force=True,
|
|
||||||
)
|
|
||||||
log_tz = app.config.get("LOG_TZ")
|
|
||||||
if log_tz:
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import pytz
|
|
||||||
|
|
||||||
timezone = pytz.timezone(log_tz)
|
|
||||||
|
|
||||||
def time_converter(seconds):
|
|
||||||
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
|
||||||
|
|
||||||
for handler in logging.root.handlers:
|
|
||||||
handler.formatter.converter = time_converter
|
|
||||||
initialize_extensions(app)
|
|
||||||
register_blueprints(app)
|
|
||||||
register_commands(app)
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_extensions(app):
|
|
||||||
# Since the application instance is now created, pass it to each Flask
|
|
||||||
# extension instance to bind it to the Flask application instance (app)
|
|
||||||
ext_compress.init_app(app)
|
|
||||||
ext_code_based_extension.init()
|
|
||||||
ext_database.init_app(app)
|
|
||||||
ext_migrate.init(app, db)
|
|
||||||
ext_redis.init_app(app)
|
|
||||||
ext_storage.init_app(app)
|
|
||||||
ext_celery.init_app(app)
|
|
||||||
ext_login.init_app(app)
|
|
||||||
ext_mail.init_app(app)
|
|
||||||
ext_hosting_provider.init_app(app)
|
|
||||||
ext_sentry.init_app(app)
|
|
||||||
ext_proxy_fix.init_app(app)
|
|
||||||
|
|
||||||
|
|
||||||
# Flask-Login configuration
|
|
||||||
@login_manager.request_loader
|
|
||||||
def load_user_from_request(request_from_flask_login):
|
|
||||||
"""Load user based on the request."""
|
|
||||||
if request.blueprint not in {"console", "inner_api"}:
|
|
||||||
return None
|
|
||||||
# Check if the user_id contains a dot, indicating the old format
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
|
||||||
if not auth_header:
|
|
||||||
auth_token = request.args.get("_token")
|
|
||||||
if not auth_token:
|
|
||||||
raise Unauthorized("Invalid Authorization token.")
|
|
||||||
else:
|
|
||||||
if " " not in auth_header:
|
|
||||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
|
||||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
|
||||||
auth_scheme = auth_scheme.lower()
|
|
||||||
if auth_scheme != "bearer":
|
|
||||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
|
||||||
|
|
||||||
decoded = PassportService().verify(auth_token)
|
|
||||||
user_id = decoded.get("user_id")
|
|
||||||
|
|
||||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
|
||||||
if logged_in_account:
|
|
||||||
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
|
||||||
return logged_in_account
|
|
||||||
|
|
||||||
|
|
||||||
@login_manager.unauthorized_handler
|
|
||||||
def unauthorized_handler():
|
|
||||||
"""Handle unauthorized requests."""
|
|
||||||
return Response(
|
|
||||||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
|
||||||
status=401,
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# register blueprint routers
|
|
||||||
def register_blueprints(app):
|
|
||||||
from controllers.console import bp as console_app_bp
|
|
||||||
from controllers.files import bp as files_bp
|
|
||||||
from controllers.inner_api import bp as inner_api_bp
|
|
||||||
from controllers.service_api import bp as service_api_bp
|
|
||||||
from controllers.web import bp as web_bp
|
|
||||||
|
|
||||||
CORS(
|
|
||||||
service_api_bp,
|
|
||||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
|
||||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
|
||||||
)
|
|
||||||
app.register_blueprint(service_api_bp)
|
|
||||||
|
|
||||||
CORS(
|
|
||||||
web_bp,
|
|
||||||
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
|
|
||||||
supports_credentials=True,
|
|
||||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
|
||||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
|
||||||
expose_headers=["X-Version", "X-Env"],
|
|
||||||
)
|
|
||||||
|
|
||||||
app.register_blueprint(web_bp)
|
|
||||||
|
|
||||||
CORS(
|
|
||||||
console_app_bp,
|
|
||||||
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
|
|
||||||
supports_credentials=True,
|
|
||||||
allow_headers=["Content-Type", "Authorization"],
|
|
||||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
|
||||||
expose_headers=["X-Version", "X-Env"],
|
|
||||||
)
|
|
||||||
|
|
||||||
app.register_blueprint(console_app_bp)
|
|
||||||
|
|
||||||
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
|
||||||
app.register_blueprint(files_bp)
|
|
||||||
|
|
||||||
app.register_blueprint(inner_api_bp)
|
|
||||||
|
|
||||||
|
|
||||||
# create app
|
# create app
|
||||||
app = create_app()
|
app = create_app()
|
||||||
celery = app.extensions["celery"]
|
celery = app.extensions["celery"]
|
||||||
|
213
api/app_factory.py
Normal file
213
api/app_factory.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get("DEBUG", "false").lower() != "true":
|
||||||
|
from gevent import monkey
|
||||||
|
|
||||||
|
monkey.patch_all()
|
||||||
|
|
||||||
|
import grpc.experimental.gevent
|
||||||
|
|
||||||
|
grpc.experimental.gevent.init_gevent()
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
|
from flask import Flask, Response, request
|
||||||
|
from flask_cors import CORS
|
||||||
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from commands import register_commands
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions import (
|
||||||
|
ext_celery,
|
||||||
|
ext_code_based_extension,
|
||||||
|
ext_compress,
|
||||||
|
ext_database,
|
||||||
|
ext_hosting_provider,
|
||||||
|
ext_login,
|
||||||
|
ext_mail,
|
||||||
|
ext_migrate,
|
||||||
|
ext_proxy_fix,
|
||||||
|
ext_redis,
|
||||||
|
ext_sentry,
|
||||||
|
ext_storage,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_login import login_manager
|
||||||
|
from libs.passport import PassportService
|
||||||
|
from services.account_service import AccountService
|
||||||
|
|
||||||
|
|
||||||
|
class DifyApp(Flask):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Application Factory Function
|
||||||
|
# ----------------------------
|
||||||
|
def create_flask_app_with_configs() -> Flask:
|
||||||
|
"""
|
||||||
|
create a raw flask app
|
||||||
|
with configs loaded from .env file
|
||||||
|
"""
|
||||||
|
dify_app = DifyApp(__name__)
|
||||||
|
dify_app.config.from_mapping(dify_config.model_dump())
|
||||||
|
|
||||||
|
# populate configs into system environment variables
|
||||||
|
for key, value in dify_app.config.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
os.environ[key] = value
|
||||||
|
elif isinstance(value, int | float | bool):
|
||||||
|
os.environ[key] = str(value)
|
||||||
|
elif value is None:
|
||||||
|
os.environ[key] = ""
|
||||||
|
|
||||||
|
return dify_app
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> Flask:
|
||||||
|
app = create_flask_app_with_configs()
|
||||||
|
|
||||||
|
app.secret_key = app.config["SECRET_KEY"]
|
||||||
|
|
||||||
|
log_handlers = None
|
||||||
|
log_file = app.config.get("LOG_FILE")
|
||||||
|
if log_file:
|
||||||
|
log_dir = os.path.dirname(log_file)
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
log_handlers = [
|
||||||
|
RotatingFileHandler(
|
||||||
|
filename=log_file,
|
||||||
|
maxBytes=1024 * 1024 * 1024,
|
||||||
|
backupCount=5,
|
||||||
|
),
|
||||||
|
logging.StreamHandler(sys.stdout),
|
||||||
|
]
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=app.config.get("LOG_LEVEL"),
|
||||||
|
format=app.config.get("LOG_FORMAT"),
|
||||||
|
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||||
|
handlers=log_handlers,
|
||||||
|
force=True,
|
||||||
|
)
|
||||||
|
log_tz = app.config.get("LOG_TZ")
|
||||||
|
if log_tz:
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
timezone = pytz.timezone(log_tz)
|
||||||
|
|
||||||
|
def time_converter(seconds):
|
||||||
|
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
||||||
|
|
||||||
|
for handler in logging.root.handlers:
|
||||||
|
handler.formatter.converter = time_converter
|
||||||
|
initialize_extensions(app)
|
||||||
|
register_blueprints(app)
|
||||||
|
register_commands(app)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_extensions(app):
|
||||||
|
# Since the application instance is now created, pass it to each Flask
|
||||||
|
# extension instance to bind it to the Flask application instance (app)
|
||||||
|
ext_compress.init_app(app)
|
||||||
|
ext_code_based_extension.init()
|
||||||
|
ext_database.init_app(app)
|
||||||
|
ext_migrate.init(app, db)
|
||||||
|
ext_redis.init_app(app)
|
||||||
|
ext_storage.init_app(app)
|
||||||
|
ext_celery.init_app(app)
|
||||||
|
ext_login.init_app(app)
|
||||||
|
ext_mail.init_app(app)
|
||||||
|
ext_hosting_provider.init_app(app)
|
||||||
|
ext_sentry.init_app(app)
|
||||||
|
ext_proxy_fix.init_app(app)
|
||||||
|
|
||||||
|
|
||||||
|
# Flask-Login configuration
|
||||||
|
@login_manager.request_loader
|
||||||
|
def load_user_from_request(request_from_flask_login):
|
||||||
|
"""Load user based on the request."""
|
||||||
|
if request.blueprint not in {"console", "inner_api"}:
|
||||||
|
return None
|
||||||
|
# Check if the user_id contains a dot, indicating the old format
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header:
|
||||||
|
auth_token = request.args.get("_token")
|
||||||
|
if not auth_token:
|
||||||
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
else:
|
||||||
|
if " " not in auth_header:
|
||||||
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||||
|
auth_scheme = auth_scheme.lower()
|
||||||
|
if auth_scheme != "bearer":
|
||||||
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
|
||||||
|
decoded = PassportService().verify(auth_token)
|
||||||
|
user_id = decoded.get("user_id")
|
||||||
|
|
||||||
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||||
|
if logged_in_account:
|
||||||
|
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
||||||
|
return logged_in_account
|
||||||
|
|
||||||
|
|
||||||
|
@login_manager.unauthorized_handler
|
||||||
|
def unauthorized_handler():
|
||||||
|
"""Handle unauthorized requests."""
|
||||||
|
return Response(
|
||||||
|
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||||
|
status=401,
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# register blueprint routers
|
||||||
|
def register_blueprints(app):
|
||||||
|
from controllers.console import bp as console_app_bp
|
||||||
|
from controllers.files import bp as files_bp
|
||||||
|
from controllers.inner_api import bp as inner_api_bp
|
||||||
|
from controllers.service_api import bp as service_api_bp
|
||||||
|
from controllers.web import bp as web_bp
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
service_api_bp,
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
)
|
||||||
|
app.register_blueprint(service_api_bp)
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
web_bp,
|
||||||
|
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
|
||||||
|
supports_credentials=True,
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
expose_headers=["X-Version", "X-Env"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.register_blueprint(web_bp)
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
console_app_bp,
|
||||||
|
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
|
||||||
|
supports_credentials=True,
|
||||||
|
allow_headers=["Content-Type", "Authorization"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
expose_headers=["X-Version", "X-Env"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.register_blueprint(console_app_bp)
|
||||||
|
|
||||||
|
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
||||||
|
app.register_blueprint(files_bp)
|
||||||
|
|
||||||
|
app.register_blueprint(inner_api_bp)
|
@ -259,6 +259,25 @@ def migrate_knowledge_vector_database():
|
|||||||
skipped_count = 0
|
skipped_count = 0
|
||||||
total_count = 0
|
total_count = 0
|
||||||
vector_type = dify_config.VECTOR_STORE
|
vector_type = dify_config.VECTOR_STORE
|
||||||
|
upper_colletion_vector_types = {
|
||||||
|
VectorType.MILVUS,
|
||||||
|
VectorType.PGVECTOR,
|
||||||
|
VectorType.RELYT,
|
||||||
|
VectorType.WEAVIATE,
|
||||||
|
VectorType.ORACLE,
|
||||||
|
VectorType.ELASTICSEARCH,
|
||||||
|
}
|
||||||
|
lower_colletion_vector_types = {
|
||||||
|
VectorType.ANALYTICDB,
|
||||||
|
VectorType.CHROMA,
|
||||||
|
VectorType.MYSCALE,
|
||||||
|
VectorType.PGVECTO_RS,
|
||||||
|
VectorType.TIDB_VECTOR,
|
||||||
|
VectorType.OPENSEARCH,
|
||||||
|
VectorType.TENCENT,
|
||||||
|
VectorType.BAIDU,
|
||||||
|
VectorType.VIKINGDB,
|
||||||
|
}
|
||||||
page = 1
|
page = 1
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -284,11 +303,9 @@ def migrate_knowledge_vector_database():
|
|||||||
skipped_count = skipped_count + 1
|
skipped_count = skipped_count + 1
|
||||||
continue
|
continue
|
||||||
collection_name = ""
|
collection_name = ""
|
||||||
if vector_type == VectorType.WEAVIATE:
|
dataset_id = dataset.id
|
||||||
dataset_id = dataset.id
|
if vector_type in upper_colletion_vector_types:
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
elif vector_type == VectorType.QDRANT:
|
elif vector_type == VectorType.QDRANT:
|
||||||
if dataset.collection_binding_id:
|
if dataset.collection_binding_id:
|
||||||
dataset_collection_binding = (
|
dataset_collection_binding = (
|
||||||
@ -301,63 +318,15 @@ def migrate_knowledge_vector_database():
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Dataset Collection Binding not found")
|
raise ValueError("Dataset Collection Binding not found")
|
||||||
else:
|
else:
|
||||||
dataset_id = dataset.id
|
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
|
|
||||||
elif vector_type == VectorType.MILVUS:
|
elif vector_type in lower_colletion_vector_types:
|
||||||
dataset_id = dataset.id
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
elif vector_type == VectorType.RELYT:
|
|
||||||
dataset_id = dataset.id
|
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
elif vector_type == VectorType.TENCENT:
|
|
||||||
dataset_id = dataset.id
|
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
elif vector_type == VectorType.PGVECTOR:
|
|
||||||
dataset_id = dataset.id
|
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
elif vector_type == VectorType.OPENSEARCH:
|
|
||||||
dataset_id = dataset.id
|
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
index_struct_dict = {
|
|
||||||
"type": VectorType.OPENSEARCH,
|
|
||||||
"vector_store": {"class_prefix": collection_name},
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
elif vector_type == VectorType.ANALYTICDB:
|
|
||||||
dataset_id = dataset.id
|
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
index_struct_dict = {
|
|
||||||
"type": VectorType.ANALYTICDB,
|
|
||||||
"vector_store": {"class_prefix": collection_name},
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
elif vector_type == VectorType.ELASTICSEARCH:
|
|
||||||
dataset_id = dataset.id
|
|
||||||
index_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
elif vector_type == VectorType.BAIDU:
|
|
||||||
dataset_id = dataset.id
|
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
index_struct_dict = {
|
|
||||||
"type": VectorType.BAIDU,
|
|
||||||
"vector_store": {"class_prefix": collection_name},
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||||
|
|
||||||
|
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
||||||
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
vector = Vector(dataset)
|
vector = Vector(dataset)
|
||||||
click.echo(f"Migrating dataset {dataset.id}.")
|
click.echo(f"Migrating dataset {dataset.id}.")
|
||||||
|
|
||||||
|
@ -506,11 +506,16 @@ class DataSetConfig(BaseSettings):
|
|||||||
Configuration for dataset management
|
Configuration for dataset management
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CLEAN_DAY_SETTING: PositiveInt = Field(
|
PLAN_SANDBOX_CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||||
description="Interval in days for dataset cleanup operations",
|
description="Interval in days for dataset cleanup operations - plan: sandbox",
|
||||||
default=30,
|
default=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PLAN_PRO_CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||||
|
description="Interval in days for dataset cleanup operations - plan: pro and team",
|
||||||
|
default=7,
|
||||||
|
)
|
||||||
|
|
||||||
DATASET_OPERATOR_ENABLED: bool = Field(
|
DATASET_OPERATOR_ENABLED: bool = Field(
|
||||||
description="Enable or disable dataset operator functionality",
|
description="Enable or disable dataset operator functionality",
|
||||||
default=False,
|
default=False,
|
||||||
|
@ -14,7 +14,7 @@ class OracleConfig(BaseSettings):
|
|||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
ORACLE_PORT: Optional[PositiveInt] = Field(
|
ORACLE_PORT: PositiveInt = Field(
|
||||||
description="Port number on which the Oracle database server is listening (default is 1521)",
|
description="Port number on which the Oracle database server is listening (default is 1521)",
|
||||||
default=1521,
|
default=1521,
|
||||||
)
|
)
|
||||||
|
@ -14,7 +14,7 @@ class PGVectorConfig(BaseSettings):
|
|||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
PGVECTOR_PORT: Optional[PositiveInt] = Field(
|
PGVECTOR_PORT: PositiveInt = Field(
|
||||||
description="Port number on which the PostgreSQL server is listening (default is 5433)",
|
description="Port number on which the PostgreSQL server is listening (default is 5433)",
|
||||||
default=5433,
|
default=5433,
|
||||||
)
|
)
|
||||||
|
@ -14,7 +14,7 @@ class PGVectoRSConfig(BaseSettings):
|
|||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
|
PGVECTO_RS_PORT: PositiveInt = Field(
|
||||||
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
|
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
|
||||||
default=5431,
|
default=5431,
|
||||||
)
|
)
|
||||||
|
@ -11,27 +11,39 @@ class VikingDBConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
|
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
|
||||||
default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
|
description="The Access Key provided by Volcengine VikingDB for API authentication."
|
||||||
|
"Refer to the following documentation for details on obtaining credentials:"
|
||||||
|
"https://www.volcengine.com/docs/6291/65568",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
VIKINGDB_SECRET_KEY: Optional[str] = Field(
|
VIKINGDB_SECRET_KEY: Optional[str] = Field(
|
||||||
default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
|
description="The Secret Key provided by Volcengine VikingDB for API authentication.",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
VIKINGDB_REGION: Optional[str] = Field(
|
|
||||||
default="cn-shanghai",
|
VIKINGDB_REGION: str = Field(
|
||||||
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
|
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
|
||||||
|
default="cn-shanghai",
|
||||||
)
|
)
|
||||||
VIKINGDB_HOST: Optional[str] = Field(
|
|
||||||
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
VIKINGDB_HOST: str = Field(
|
||||||
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
|
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
|
||||||
'api-vikingdb.mlp.cn-shanghai.volces.com')",
|
'api-vikingdb.mlp.cn-shanghai.volces.com')",
|
||||||
|
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
||||||
)
|
)
|
||||||
VIKINGDB_SCHEME: Optional[str] = Field(
|
|
||||||
default="http",
|
VIKINGDB_SCHEME: str = Field(
|
||||||
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
|
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
|
||||||
|
default="http",
|
||||||
)
|
)
|
||||||
VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
|
|
||||||
default=30, description="The connection timeout of the Volcengine VikingDB service."
|
VIKINGDB_CONNECTION_TIMEOUT: int = Field(
|
||||||
|
description="The connection timeout of the Volcengine VikingDB service.",
|
||||||
|
default=30,
|
||||||
)
|
)
|
||||||
VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
|
|
||||||
default=30, description="The socket timeout of the Volcengine VikingDB service."
|
VIKINGDB_SOCKET_TIMEOUT: int = Field(
|
||||||
|
description="The socket timeout of the Volcengine VikingDB service.",
|
||||||
|
default=30,
|
||||||
)
|
)
|
||||||
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
|||||||
|
|
||||||
CURRENT_VERSION: str = Field(
|
CURRENT_VERSION: str = Field(
|
||||||
description="Dify version",
|
description="Dify version",
|
||||||
default="0.9.1",
|
default="0.9.2",
|
||||||
)
|
)
|
||||||
|
|
||||||
COMMIT_SHA: str = Field(
|
COMMIT_SHA: str = Field(
|
||||||
|
@ -1,88 +1,24 @@
|
|||||||
import logging
|
from flask_restful import Resource
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restful import Resource, marshal, reqparse
|
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
|
||||||
|
|
||||||
import services
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import (
|
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||||
CompletionRequestError,
|
|
||||||
ProviderModelCurrentlyNotSupportError,
|
|
||||||
ProviderNotInitializeError,
|
|
||||||
ProviderQuotaExceededError,
|
|
||||||
)
|
|
||||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.errors.error import (
|
|
||||||
LLMBadRequestError,
|
|
||||||
ModelCurrentlyNotSupportError,
|
|
||||||
ProviderTokenNotInitError,
|
|
||||||
QuotaExceededError,
|
|
||||||
)
|
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
|
||||||
from fields.hit_testing_fields import hit_testing_record_fields
|
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.dataset_service import DatasetService
|
|
||||||
from services.hit_testing_service import HitTestingService
|
|
||||||
|
|
||||||
|
|
||||||
class HitTestingApi(Resource):
|
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
args = self.parse_args()
|
||||||
raise NotFound("Dataset not found.")
|
self.hit_testing_args_check(args)
|
||||||
|
|
||||||
try:
|
return self.perform_hit_testing(dataset, args)
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
|
||||||
except services.errors.account.NoPermissionError as e:
|
|
||||||
raise Forbidden(str(e))
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("query", type=str, location="json")
|
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = HitTestingService.retrieve(
|
|
||||||
dataset=dataset,
|
|
||||||
query=args["query"],
|
|
||||||
account=current_user,
|
|
||||||
retrieval_model=args["retrieval_model"],
|
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
|
||||||
except services.errors.index.IndexNotInitializedError:
|
|
||||||
raise DatasetNotInitializedError()
|
|
||||||
except ProviderTokenNotInitError as ex:
|
|
||||||
raise ProviderNotInitializeError(ex.description)
|
|
||||||
except QuotaExceededError:
|
|
||||||
raise ProviderQuotaExceededError()
|
|
||||||
except ModelCurrentlyNotSupportError:
|
|
||||||
raise ProviderModelCurrentlyNotSupportError()
|
|
||||||
except LLMBadRequestError:
|
|
||||||
raise ProviderNotInitializeError(
|
|
||||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
|
||||||
"in the Settings -> Model Provider."
|
|
||||||
)
|
|
||||||
except InvokeError as e:
|
|
||||||
raise CompletionRequestError(e.description)
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError(str(e))
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception("Hit testing failed.")
|
|
||||||
raise InternalServerError(str(e))
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||||
|
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import marshal, reqparse
|
||||||
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
|
import services.dataset_service
|
||||||
|
from controllers.console.app.error import (
|
||||||
|
CompletionRequestError,
|
||||||
|
ProviderModelCurrentlyNotSupportError,
|
||||||
|
ProviderNotInitializeError,
|
||||||
|
ProviderQuotaExceededError,
|
||||||
|
)
|
||||||
|
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||||
|
from core.errors.error import (
|
||||||
|
LLMBadRequestError,
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
|
from fields.hit_testing_fields import hit_testing_record_fields
|
||||||
|
from services.dataset_service import DatasetService
|
||||||
|
from services.hit_testing_service import HitTestingService
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetsHitTestingBase:
|
||||||
|
@staticmethod
|
||||||
|
def get_and_validate_dataset(dataset_id: str):
|
||||||
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
|
if dataset is None:
|
||||||
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
except services.errors.account.NoPermissionError as e:
|
||||||
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hit_testing_args_check(args):
|
||||||
|
HitTestingService.hit_testing_args_check(args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_args():
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
|
parser.add_argument("query", type=str, location="json")
|
||||||
|
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||||
|
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def perform_hit_testing(dataset, args):
|
||||||
|
try:
|
||||||
|
response = HitTestingService.retrieve(
|
||||||
|
dataset=dataset,
|
||||||
|
query=args["query"],
|
||||||
|
account=current_user,
|
||||||
|
retrieval_model=args["retrieval_model"],
|
||||||
|
external_retrieval_model=args["external_retrieval_model"],
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||||
|
except services.errors.index.IndexNotInitializedError:
|
||||||
|
raise DatasetNotInitializedError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except LLMBadRequestError:
|
||||||
|
raise ProviderNotInitializeError(
|
||||||
|
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||||
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("Hit testing failed.")
|
||||||
|
raise InternalServerError(str(e))
|
@ -5,7 +5,6 @@ from libs.external_api import ExternalApi
|
|||||||
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
|
|
||||||
from . import index
|
from . import index
|
||||||
from .app import app, audio, completion, conversation, file, message, workflow
|
from .app import app, audio, completion, conversation, file, message, workflow
|
||||||
from .dataset import dataset, document, segment
|
from .dataset import dataset, document, hit_testing, segment
|
||||||
|
@ -4,7 +4,6 @@ from flask_restful import Resource, reqparse
|
|||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from constants import UUID_NIL
|
|
||||||
from controllers.service_api import api
|
from controllers.service_api import api
|
||||||
from controllers.service_api.app.error import (
|
from controllers.service_api.app.error import (
|
||||||
AppUnavailableError,
|
AppUnavailableError,
|
||||||
@ -108,7 +107,6 @@ class ChatApi(Resource):
|
|||||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, default=UUID_NIL, location="json")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||||
|
from controllers.service_api import api
|
||||||
|
from controllers.service_api.wraps import DatasetApiResource
|
||||||
|
|
||||||
|
|
||||||
|
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||||
|
def post(self, tenant_id, dataset_id):
|
||||||
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
||||||
|
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||||
|
args = self.parse_args()
|
||||||
|
self.hit_testing_args_check(args)
|
||||||
|
|
||||||
|
return self.perform_hit_testing(dataset, args)
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
@ -62,6 +62,8 @@ class CotAgentOutputParser:
|
|||||||
thought_str = "thought:"
|
thought_str = "thought:"
|
||||||
thought_idx = 0
|
thought_idx = 0
|
||||||
|
|
||||||
|
last_character = ""
|
||||||
|
|
||||||
for response in llm_response:
|
for response in llm_response:
|
||||||
if response.delta.usage:
|
if response.delta.usage:
|
||||||
usage_dict["usage"] = response.delta.usage
|
usage_dict["usage"] = response.delta.usage
|
||||||
@ -74,35 +76,38 @@ class CotAgentOutputParser:
|
|||||||
while index < len(response):
|
while index < len(response):
|
||||||
steps = 1
|
steps = 1
|
||||||
delta = response[index : index + steps]
|
delta = response[index : index + steps]
|
||||||
last_character = response[index - 1] if index > 0 else ""
|
yield_delta = False
|
||||||
|
|
||||||
if delta == "`":
|
if delta == "`":
|
||||||
|
last_character = delta
|
||||||
code_block_cache += delta
|
code_block_cache += delta
|
||||||
code_block_delimiter_count += 1
|
code_block_delimiter_count += 1
|
||||||
else:
|
else:
|
||||||
if not in_code_block:
|
if not in_code_block:
|
||||||
if code_block_delimiter_count > 0:
|
if code_block_delimiter_count > 0:
|
||||||
|
last_character = delta
|
||||||
yield code_block_cache
|
yield code_block_cache
|
||||||
code_block_cache = ""
|
code_block_cache = ""
|
||||||
else:
|
else:
|
||||||
|
last_character = delta
|
||||||
code_block_cache += delta
|
code_block_cache += delta
|
||||||
code_block_delimiter_count = 0
|
code_block_delimiter_count = 0
|
||||||
|
|
||||||
if not in_code_block and not in_json:
|
if not in_code_block and not in_json:
|
||||||
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
||||||
if last_character not in {"\n", " ", ""}:
|
if last_character not in {"\n", " ", ""}:
|
||||||
|
yield_delta = True
|
||||||
|
else:
|
||||||
|
last_character = delta
|
||||||
|
action_cache += delta
|
||||||
|
action_idx += 1
|
||||||
|
if action_idx == len(action_str):
|
||||||
|
action_cache = ""
|
||||||
|
action_idx = 0
|
||||||
index += steps
|
index += steps
|
||||||
yield delta
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
action_cache += delta
|
|
||||||
action_idx += 1
|
|
||||||
if action_idx == len(action_str):
|
|
||||||
action_cache = ""
|
|
||||||
action_idx = 0
|
|
||||||
index += steps
|
|
||||||
continue
|
|
||||||
elif delta.lower() == action_str[action_idx] and action_idx > 0:
|
elif delta.lower() == action_str[action_idx] and action_idx > 0:
|
||||||
|
last_character = delta
|
||||||
action_cache += delta
|
action_cache += delta
|
||||||
action_idx += 1
|
action_idx += 1
|
||||||
if action_idx == len(action_str):
|
if action_idx == len(action_str):
|
||||||
@ -112,24 +117,25 @@ class CotAgentOutputParser:
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
if action_cache:
|
if action_cache:
|
||||||
|
last_character = delta
|
||||||
yield action_cache
|
yield action_cache
|
||||||
action_cache = ""
|
action_cache = ""
|
||||||
action_idx = 0
|
action_idx = 0
|
||||||
|
|
||||||
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
||||||
if last_character not in {"\n", " ", ""}:
|
if last_character not in {"\n", " ", ""}:
|
||||||
|
yield_delta = True
|
||||||
|
else:
|
||||||
|
last_character = delta
|
||||||
|
thought_cache += delta
|
||||||
|
thought_idx += 1
|
||||||
|
if thought_idx == len(thought_str):
|
||||||
|
thought_cache = ""
|
||||||
|
thought_idx = 0
|
||||||
index += steps
|
index += steps
|
||||||
yield delta
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
thought_cache += delta
|
|
||||||
thought_idx += 1
|
|
||||||
if thought_idx == len(thought_str):
|
|
||||||
thought_cache = ""
|
|
||||||
thought_idx = 0
|
|
||||||
index += steps
|
|
||||||
continue
|
|
||||||
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
|
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
|
||||||
|
last_character = delta
|
||||||
thought_cache += delta
|
thought_cache += delta
|
||||||
thought_idx += 1
|
thought_idx += 1
|
||||||
if thought_idx == len(thought_str):
|
if thought_idx == len(thought_str):
|
||||||
@ -139,12 +145,20 @@ class CotAgentOutputParser:
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
if thought_cache:
|
if thought_cache:
|
||||||
|
last_character = delta
|
||||||
yield thought_cache
|
yield thought_cache
|
||||||
thought_cache = ""
|
thought_cache = ""
|
||||||
thought_idx = 0
|
thought_idx = 0
|
||||||
|
|
||||||
|
if yield_delta:
|
||||||
|
index += steps
|
||||||
|
last_character = delta
|
||||||
|
yield delta
|
||||||
|
continue
|
||||||
|
|
||||||
if code_block_delimiter_count == 3:
|
if code_block_delimiter_count == 3:
|
||||||
if in_code_block:
|
if in_code_block:
|
||||||
|
last_character = delta
|
||||||
yield from extra_json_from_code_block(code_block_cache)
|
yield from extra_json_from_code_block(code_block_cache)
|
||||||
code_block_cache = ""
|
code_block_cache = ""
|
||||||
|
|
||||||
@ -156,8 +170,10 @@ class CotAgentOutputParser:
|
|||||||
if delta == "{":
|
if delta == "{":
|
||||||
json_quote_count += 1
|
json_quote_count += 1
|
||||||
in_json = True
|
in_json = True
|
||||||
|
last_character = delta
|
||||||
json_cache += delta
|
json_cache += delta
|
||||||
elif delta == "}":
|
elif delta == "}":
|
||||||
|
last_character = delta
|
||||||
json_cache += delta
|
json_cache += delta
|
||||||
if json_quote_count > 0:
|
if json_quote_count > 0:
|
||||||
json_quote_count -= 1
|
json_quote_count -= 1
|
||||||
@ -168,16 +184,19 @@ class CotAgentOutputParser:
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
if in_json:
|
if in_json:
|
||||||
|
last_character = delta
|
||||||
json_cache += delta
|
json_cache += delta
|
||||||
|
|
||||||
if got_json:
|
if got_json:
|
||||||
got_json = False
|
got_json = False
|
||||||
|
last_character = delta
|
||||||
yield parse_action(json_cache)
|
yield parse_action(json_cache)
|
||||||
json_cache = ""
|
json_cache = ""
|
||||||
json_quote_count = 0
|
json_quote_count = 0
|
||||||
in_json = False
|
in_json = False
|
||||||
|
|
||||||
if not in_code_block and not in_json:
|
if not in_code_block and not in_json:
|
||||||
|
last_character = delta
|
||||||
yield delta.replace("`", "")
|
yield delta.replace("`", "")
|
||||||
|
|
||||||
index += steps
|
index += steps
|
||||||
|
@ -10,6 +10,7 @@ from flask import Flask, current_app
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||||
@ -122,7 +123,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||||
@ -127,7 +128,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
@ -128,7 +129,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
@ -2,8 +2,9 @@ from collections.abc import Mapping
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||||
|
|
||||||
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||||
from core.entities.provider_configuration import ProviderModelBundle
|
from core.entities.provider_configuration import ProviderModelBundle
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
@ -116,13 +117,36 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
|||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||||
|
"""
|
||||||
|
Base entity for conversation-based app generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
|
parent_message_id: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
|
||||||
|
"For service API, we need to ensure its forward compatibility, "
|
||||||
|
"so passing in the parent_message_id as request arg is not supported for now. "
|
||||||
|
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("parent_message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_parent_message_id(cls, v, info: ValidationInfo):
|
||||||
|
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
|
||||||
|
raise ValueError("parent_message_id should be UUID_NIL for service API")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
Chat Application Generate Entity.
|
Chat Application Generate Entity.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
conversation_id: Optional[str] = None
|
pass
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||||
@ -133,16 +157,15 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
Agent Chat Application Generate Entity.
|
Agent Chat Application Generate Entity.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
conversation_id: Optional[str] = None
|
pass
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
Advanced Chat Application Generate Entity.
|
Advanced Chat Application Generate Entity.
|
||||||
"""
|
"""
|
||||||
@ -150,8 +173,6 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
|||||||
# app config
|
# app config
|
||||||
app_config: WorkflowUIBasedAppConfig
|
app_config: WorkflowUIBasedAppConfig
|
||||||
|
|
||||||
conversation_id: Optional[str] = None
|
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
workflow_run_id: Optional[str] = None
|
workflow_run_id: Optional[str] = None
|
||||||
query: str
|
query: str
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
from collections.abc import Callable, Generator, Sequence
|
from collections.abc import Callable, Generator, Sequence
|
||||||
from typing import IO, Optional, Union, cast
|
from typing import IO, Optional, Union, cast
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||||
from core.errors.error import ProviderTokenNotInitError
|
from core.errors.error import ProviderTokenNotInitError
|
||||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
|
@ -1098,6 +1098,14 @@ LLM_BASE_MODELS = [
|
|||||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
|
ParameterRule(
|
||||||
|
name="temperature",
|
||||||
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name="top_p",
|
||||||
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name="response_format",
|
name="response_format",
|
||||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||||
@ -1135,6 +1143,14 @@ LLM_BASE_MODELS = [
|
|||||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
|
ParameterRule(
|
||||||
|
name="temperature",
|
||||||
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name="top_p",
|
||||||
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name="response_format",
|
name="response_format",
|
||||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||||
|
@ -119,7 +119,15 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
try:
|
try:
|
||||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||||
|
|
||||||
if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
if model.startswith("o1"):
|
||||||
|
client.chat.completions.create(
|
||||||
|
messages=[{"role": "user", "content": "ping"}],
|
||||||
|
model=model,
|
||||||
|
temperature=1,
|
||||||
|
max_completion_tokens=20,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||||
# chat model
|
# chat model
|
||||||
client.chat.completions.create(
|
client.chat.completions.create(
|
||||||
messages=[{"role": "user", "content": "ping"}],
|
messages=[{"role": "user", "content": "ping"}],
|
||||||
|
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
|
from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from requests import post
|
from requests import post
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
|
@ -52,6 +52,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00025'
|
input: '0.00025'
|
||||||
output: '0.00125'
|
output: '0.00125'
|
||||||
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.003'
|
input: '0.003'
|
||||||
output: '0.015'
|
output: '0.015'
|
||||||
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.003'
|
input: '0.003'
|
||||||
output: '0.015'
|
output: '0.015'
|
||||||
|
@ -52,6 +52,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00025'
|
input: '0.00025'
|
||||||
output: '0.00125'
|
output: '0.00125'
|
||||||
|
@ -52,6 +52,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.015'
|
input: '0.015'
|
||||||
output: '0.075'
|
output: '0.075'
|
||||||
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.003'
|
input: '0.003'
|
||||||
output: '0.015'
|
output: '0.015'
|
||||||
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.003'
|
input: '0.003'
|
||||||
output: '0.015'
|
output: '0.015'
|
||||||
|
@ -13,7 +13,7 @@ from botocore.exceptions import (
|
|||||||
UnknownServiceError,
|
UnknownServiceError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
|
@ -5,7 +5,7 @@ import cohere
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from cohere.core import RequestOptions
|
from cohere.core import RequestOptions
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
|
@ -18,6 +18,7 @@ supported_model_types:
|
|||||||
- text-embedding
|
- text-embedding
|
||||||
configurate_methods:
|
configurate_methods:
|
||||||
- predefined-model
|
- predefined-model
|
||||||
|
- customizable-model
|
||||||
provider_credential_schema:
|
provider_credential_schema:
|
||||||
credential_form_schemas:
|
credential_form_schemas:
|
||||||
- variable: fireworks_api_key
|
- variable: fireworks_api_key
|
||||||
@ -28,3 +29,75 @@ provider_credential_schema:
|
|||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入您的 API Key
|
zh_Hans: 在此输入您的 API Key
|
||||||
en_US: Enter your API Key
|
en_US: Enter your API Key
|
||||||
|
model_credential_schema:
|
||||||
|
model:
|
||||||
|
label:
|
||||||
|
en_US: Model URL
|
||||||
|
zh_Hans: 模型URL
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your Model URL
|
||||||
|
zh_Hans: 输入模型URL
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: model_label_zh_Hanns
|
||||||
|
label:
|
||||||
|
zh_Hans: 模型中文名称
|
||||||
|
en_US: The zh_Hans of Model
|
||||||
|
required: true
|
||||||
|
type: text-input
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的模型中文名称
|
||||||
|
en_US: Enter your zh_Hans of Model
|
||||||
|
- variable: model_label_en_US
|
||||||
|
label:
|
||||||
|
zh_Hans: 模型英文名称
|
||||||
|
en_US: The en_US of Model
|
||||||
|
required: true
|
||||||
|
type: text-input
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的模型英文名称
|
||||||
|
en_US: Enter your en_US of Model
|
||||||
|
- variable: fireworks_api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
||||||
|
- variable: context_size
|
||||||
|
label:
|
||||||
|
zh_Hans: 模型上下文长度
|
||||||
|
en_US: Model context size
|
||||||
|
required: true
|
||||||
|
type: text-input
|
||||||
|
default: '4096'
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的模型上下文长度
|
||||||
|
en_US: Enter your Model context size
|
||||||
|
- variable: max_tokens
|
||||||
|
label:
|
||||||
|
zh_Hans: 最大 token 上限
|
||||||
|
en_US: Upper bound for max tokens
|
||||||
|
default: '4096'
|
||||||
|
type: text-input
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
- variable: function_calling_type
|
||||||
|
label:
|
||||||
|
en_US: Function calling
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
default: no_call
|
||||||
|
options:
|
||||||
|
- value: no_call
|
||||||
|
label:
|
||||||
|
en_US: Not Support
|
||||||
|
zh_Hans: 不支持
|
||||||
|
- value: function_call
|
||||||
|
label:
|
||||||
|
en_US: Support
|
||||||
|
zh_Hans: 支持
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
@ -43,3 +43,4 @@ pricing:
|
|||||||
output: '0.2'
|
output: '0.2'
|
||||||
unit: '0.000001'
|
unit: '0.000001'
|
||||||
currency: USD
|
currency: USD
|
||||||
|
deprecated: true
|
||||||
|
@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
|
|||||||
from openai.types.chat.chat_completion_message import FunctionCall
|
from openai.types.chat.chat_completion_message import FunctionCall
|
||||||
|
|
||||||
from core.model_runtime.callbacks.base_callback import Callback
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
ToolPromptMessage,
|
ToolPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
FetchFrom,
|
||||||
|
ModelFeature,
|
||||||
|
ModelPropertyKey,
|
||||||
|
ModelType,
|
||||||
|
ParameterRule,
|
||||||
|
ParameterType,
|
||||||
|
)
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
||||||
@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
|
|||||||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
||||||
|
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
return AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(
|
||||||
|
en_US=credentials.get("model_label_en_US", model),
|
||||||
|
zh_Hans=credentials.get("model_label_zh_Hanns", model),
|
||||||
|
),
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||||
|
if credentials.get("function_calling_type") == "function_call"
|
||||||
|
else [],
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={
|
||||||
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
||||||
|
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||||
|
},
|
||||||
|
parameter_rules=[
|
||||||
|
ParameterRule(
|
||||||
|
name="temperature",
|
||||||
|
use_template="temperature",
|
||||||
|
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name="max_tokens",
|
||||||
|
use_template="max_tokens",
|
||||||
|
default=512,
|
||||||
|
min=1,
|
||||||
|
max=int(credentials.get("max_tokens", 4096)),
|
||||||
|
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||||
|
type=ParameterType.INT,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name="top_p",
|
||||||
|
use_template="top_p",
|
||||||
|
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name="top_k",
|
||||||
|
use_template="top_k",
|
||||||
|
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -0,0 +1,46 @@
|
|||||||
|
model: accounts/fireworks/models/qwen2p5-72b-instruct
|
||||||
|
label:
|
||||||
|
zh_Hans: Qwen2.5 72B Instruct
|
||||||
|
en_US: Qwen2.5 72B Instruct
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 32768
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
- name: context_length_exceeded_behavior
|
||||||
|
default: None
|
||||||
|
label:
|
||||||
|
zh_Hans: 上下文长度超出行为
|
||||||
|
en_US: Context Length Exceeded Behavior
|
||||||
|
help:
|
||||||
|
zh_Hans: 上下文长度超出行为
|
||||||
|
en_US: Context Length Exceeded Behavior
|
||||||
|
type: string
|
||||||
|
options:
|
||||||
|
- None
|
||||||
|
- truncate
|
||||||
|
- error
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
|
pricing:
|
||||||
|
input: '0.9'
|
||||||
|
output: '0.9'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
@ -5,7 +5,7 @@ from typing import Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -32,15 +32,6 @@ parameter_rules:
|
|||||||
max: 8192
|
max: 8192
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -27,15 +27,6 @@ parameter_rules:
|
|||||||
default: 4096
|
default: 4096
|
||||||
min: 1
|
min: 1
|
||||||
max: 4096
|
max: 4096
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -31,15 +31,6 @@ parameter_rules:
|
|||||||
max: 2048
|
max: 2048
|
||||||
- name: response_format
|
- name: response_format
|
||||||
use_template: response_format
|
use_template: response_format
|
||||||
- name: stream
|
|
||||||
label:
|
|
||||||
zh_Hans: 流式输出
|
|
||||||
en_US: Stream
|
|
||||||
type: boolean
|
|
||||||
help:
|
|
||||||
zh_Hans: 流式输出允许模型在生成文本的过程中逐步返回结果,而不是一次性生成全部结果后再返回。
|
|
||||||
en_US: Streaming output allows the model to return results incrementally as it generates text, rather than generating all the results at once.
|
|
||||||
default: false
|
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00'
|
input: '0.00'
|
||||||
output: '0.00'
|
output: '0.00'
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import requests
|
import requests
|
||||||
from huggingface_hub import HfApi, InferenceClient
|
from huggingface_hub import HfApi, InferenceClient
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
@ -9,7 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile
|
|||||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from requests import post
|
from requests import post
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
@ -5,7 +5,7 @@ from typing import Optional
|
|||||||
from requests import post
|
from requests import post
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
model: abab6.5t-chat
|
||||||
|
label:
|
||||||
|
en_US: Abab6.5t-Chat
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 8192
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
min: 0.01
|
||||||
|
max: 1
|
||||||
|
default: 0.9
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
min: 0.01
|
||||||
|
max: 1
|
||||||
|
default: 0.95
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
required: true
|
||||||
|
default: 3072
|
||||||
|
min: 1
|
||||||
|
max: 8192
|
||||||
|
- name: mask_sensitive_info
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
label:
|
||||||
|
zh_Hans: 隐私保护
|
||||||
|
en_US: Moderate
|
||||||
|
help:
|
||||||
|
zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
|
||||||
|
en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
|
||||||
|
- name: presence_penalty
|
||||||
|
use_template: presence_penalty
|
||||||
|
- name: frequency_penalty
|
||||||
|
use_template: frequency_penalty
|
||||||
|
pricing:
|
||||||
|
input: '0.005'
|
||||||
|
output: '0.005'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: RMB
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from requests import post
|
from requests import post
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
@ -61,7 +61,8 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
url = f"{self.api_base}?GroupId={group_id}"
|
url = f"{self.api_base}?GroupId={group_id}"
|
||||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||||
|
|
||||||
data = {"model": "embo-01", "texts": texts, "type": "db"}
|
embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query"
|
||||||
|
data = {"model": "embo-01", "texts": texts, "type": embedding_type}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = post(url, headers=headers, data=dumps(data))
|
response = post(url, headers=headers, data=dumps(data))
|
||||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
@ -5,7 +5,7 @@ from typing import Optional
|
|||||||
from nomic import embed
|
from nomic import embed
|
||||||
from nomic import login as nomic_login
|
from nomic import login as nomic_login
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import (
|
from core.model_runtime.entities.text_embedding_entities import (
|
||||||
EmbeddingUsage,
|
EmbeddingUsage,
|
||||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from requests import post
|
from requests import post
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import oci
|
import oci
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
|
@ -8,7 +8,7 @@ from urllib.parse import urljoin
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.entities.model_entities import (
|
||||||
AIModelEntity,
|
AIModelEntity,
|
||||||
|
@ -19,9 +19,9 @@ class OpenAIProvider(ModelProvider):
|
|||||||
try:
|
try:
|
||||||
model_instance = self.get_model_instance(ModelType.LLM)
|
model_instance = self.get_model_instance(ModelType.LLM)
|
||||||
|
|
||||||
# Use `gpt-3.5-turbo` model for validate,
|
# Use `gpt-4o-mini` model for validate,
|
||||||
# no matter what model you pass in, text completion model or chat model
|
# no matter what model you pass in, text completion model or chat model
|
||||||
model_instance.validate_credentials(model="gpt-3.5-turbo", credentials=credentials)
|
model_instance.validate_credentials(model="gpt-4o-mini", credentials=credentials)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
@ -7,7 +7,7 @@ from urllib.parse import urljoin
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.entities.model_entities import (
|
||||||
AIModelEntity,
|
AIModelEntity,
|
||||||
|
@ -5,7 +5,7 @@ from typing import Optional
|
|||||||
from requests import post
|
from requests import post
|
||||||
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
|
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.model_entities import PriceType
|
from core.model_runtime.entities.model_entities import PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.invoke import (
|
from core.model_runtime.errors.invoke import (
|
||||||
|
@ -35,6 +35,15 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
|
zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
|
||||||
en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
|
en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
use_template: frequency_penalty
|
use_template: frequency_penalty
|
||||||
default: 0
|
default: 0
|
||||||
|
@ -18,6 +18,15 @@ parameter_rules:
|
|||||||
min: 0
|
min: 0
|
||||||
max: 1
|
max: 1
|
||||||
default: 1
|
default: 1
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
min: 1
|
min: 1
|
||||||
|
@ -14,6 +14,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: presence_penalty
|
- name: presence_penalty
|
||||||
use_template: presence_penalty
|
use_template: presence_penalty
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
|
@ -14,6 +14,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: presence_penalty
|
- name: presence_penalty
|
||||||
use_template: presence_penalty
|
use_template: presence_penalty
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
|
@ -14,6 +14,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: presence_penalty
|
- name: presence_penalty
|
||||||
use_template: presence_penalty
|
use_template: presence_penalty
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
|
@ -16,6 +16,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: presence_penalty
|
- name: presence_penalty
|
||||||
use_template: presence_penalty
|
use_template: presence_penalty
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
|
@ -15,6 +15,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: presence_penalty
|
- name: presence_penalty
|
||||||
use_template: presence_penalty
|
use_template: presence_penalty
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
|
@ -15,6 +15,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: presence_penalty
|
- name: presence_penalty
|
||||||
use_template: presence_penalty
|
use_template: presence_penalty
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
|
@ -10,6 +10,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
required: true
|
required: true
|
||||||
|
@ -10,6 +10,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
required: true
|
required: true
|
||||||
|
@ -10,6 +10,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
required: true
|
required: true
|
||||||
|
@ -10,6 +10,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
required: true
|
required: true
|
||||||
|
@ -10,6 +10,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
required: true
|
required: true
|
||||||
|
@ -18,6 +18,15 @@ parameter_rules:
|
|||||||
default: 1
|
default: 1
|
||||||
min: 0
|
min: 0
|
||||||
max: 1
|
max: 1
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
default: 1024
|
default: 1024
|
||||||
|
@ -18,6 +18,15 @@ parameter_rules:
|
|||||||
default: 1
|
default: 1
|
||||||
min: 0
|
min: 0
|
||||||
max: 1
|
max: 1
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
default: 1024
|
default: 1024
|
||||||
|
@ -19,6 +19,15 @@ parameter_rules:
|
|||||||
default: 1
|
default: 1
|
||||||
min: 0
|
min: 0
|
||||||
max: 1
|
max: 1
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: max_tokens
|
- name: max_tokens
|
||||||
use_template: max_tokens
|
use_template: max_tokens
|
||||||
default: 1024
|
default: 1024
|
||||||
|
@ -12,6 +12,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: presence_penalty
|
- name: presence_penalty
|
||||||
use_template: presence_penalty
|
use_template: presence_penalty
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
|
@ -12,6 +12,15 @@ parameter_rules:
|
|||||||
use_template: temperature
|
use_template: temperature
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: presence_penalty
|
- name: presence_penalty
|
||||||
use_template: presence_penalty
|
use_template: presence_penalty
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
|
@ -21,6 +21,15 @@ parameter_rules:
|
|||||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
use_template: frequency_penalty
|
use_template: frequency_penalty
|
||||||
pricing:
|
pricing:
|
||||||
|
@ -21,6 +21,15 @@ parameter_rules:
|
|||||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
use_template: frequency_penalty
|
use_template: frequency_penalty
|
||||||
pricing:
|
pricing:
|
||||||
|
@ -7,7 +7,7 @@ from urllib.parse import urljoin
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.entities.model_entities import (
|
||||||
AIModelEntity,
|
AIModelEntity,
|
||||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from replicate import Client as ReplicateClient
|
from replicate import Client as ReplicateClient
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||||
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
|
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
|
||||||
|
|
||||||
@ -77,7 +78,8 @@ class SageMakerSpeech2TextModel(Speech2TextModel):
|
|||||||
json_obj = json.loads(json_str)
|
json_obj = json.loads(json_str)
|
||||||
asr_text = json_obj["text"]
|
asr_text = json_obj["text"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Exception {e}, line : {line}")
|
logger.exception(f"failed to invoke speech2text model, {e}")
|
||||||
|
raise CredentialsValidateFailedError(str(e))
|
||||||
|
|
||||||
return asr_text
|
return asr_text
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
@ -21,6 +21,15 @@ parameter_rules:
|
|||||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||||
- name: top_p
|
- name: top_p
|
||||||
use_template: top_p
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
- name: frequency_penalty
|
- name: frequency_penalty
|
||||||
use_template: frequency_penalty
|
use_template: frequency_penalty
|
||||||
pricing:
|
pricing:
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user