Merge branch 'main' into jzh

JzoNg committed 2024-08-21 13:40:04 +08:00
commit fff40aae58
366 changed files with 8078 additions and 3905 deletions

View File

@@ -45,6 +45,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
- name: Ruff formatter check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff format --check ./api
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."

.gitignore vendored (1 line changed)
View File

@@ -178,3 +178,4 @@ pyrightconfig.json
api/.vscode
.idea/
.vscode

View File

@@ -268,3 +268,12 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=

View File

(image file changed; 1.7 KiB before and after)

View File

@@ -5,8 +5,8 @@
"name": "Python: Flask",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/api/.venv/bin/python",
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/.venv/bin/python",
"cwd": "${workspaceFolder}",
"envFile": ".env",
"module": "flask",
"justMyCode": true,
@@ -18,15 +18,15 @@
"args": [
"run",
"--host=0.0.0.0",
"--port=5001",
"--port=5001"
]
},
{
"name": "Python: Celery",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/api/.venv/bin/python",
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/.venv/bin/python",
"cwd": "${workspaceFolder}",
"module": "celery",
"justMyCode": true,
"envFile": ".env",

View File

@@ -1,6 +1,6 @@
import os
if os.environ.get("DEBUG", "false").lower() != 'true':
if os.environ.get("DEBUG", "false").lower() != "true":
from gevent import monkey
monkey.patch_all()
@@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
if os.name == "nt":
os.system('tzutil /s "UTC"')
else:
os.environ['TZ'] = 'UTC'
os.environ["TZ"] = "UTC"
time.tzset()
@@ -70,13 +70,14 @@ class DifyApp(Flask):
# -------------
config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> Flask:
"""
create a raw flask app
@@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask:
elif isinstance(value, int | float | bool):
os.environ[key] = str(value)
elif value is None:
os.environ[key] = ''
os.environ[key] = ""
return dify_app
@@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask:
def create_app() -> Flask:
app = create_flask_app_with_configs()
app.secret_key = app.config['SECRET_KEY']
app.secret_key = app.config["SECRET_KEY"]
log_handlers = None
log_file = app.config.get('LOG_FILE')
log_file = app.config.get("LOG_FILE")
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
@@ -111,23 +112,24 @@ def create_app() -> Flask:
RotatingFileHandler(
filename=log_file,
maxBytes=1024 * 1024 * 1024,
backupCount=5
backupCount=5,
),
logging.StreamHandler(sys.stdout)
logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
level=app.config.get('LOG_LEVEL'),
format=app.config.get('LOG_FORMAT'),
datefmt=app.config.get('LOG_DATEFORMAT'),
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
handlers=log_handlers,
force=True
force=True,
)
log_tz = app.config.get('LOG_TZ')
log_tz = app.config.get("LOG_TZ")
if log_tz:
from datetime import datetime
import pytz
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
@@ -162,24 +164,24 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in ['console', 'inner_api']:
if request.blueprint not in ["console", "inner_api"]:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
auth_header = request.headers.get("Authorization", "")
if not auth_header:
auth_token = request.args.get('_token')
auth_token = request.args.get("_token")
if not auth_token:
raise Unauthorized('Invalid Authorization token.')
raise Unauthorized("Invalid Authorization token.")
else:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')
user_id = decoded.get("user_id")
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
if account:
@@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login):
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
return Response(json.dumps({
'code': 'unauthorized',
'message': "Unauthorized."
}), status=401, content_type="application/json")
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
content_type="application/json",
)
# register blueprint routers
@@ -204,38 +207,36 @@ def register_blueprints(app):
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
CORS(service_api_bp,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
CORS(web_bp,
resources={
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)
CORS(
web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(web_bp)
CORS(console_app_bp,
resources={
r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)
CORS(
console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(console_app_bp)
CORS(files_bp,
allow_headers=['Content-Type'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp)
@@ -245,29 +246,29 @@ def register_blueprints(app):
app = create_app()
celery = app.extensions["celery"]
if app.config.get('TESTING'):
if app.config.get("TESTING"):
print("App is running in TESTING mode")
@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
response.set_cookie("remember_token", "", expires=0)
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
return response
@app.route('/health')
@app.route("/health")
def health():
return Response(json.dumps({
'pid': os.getpid(),
'status': 'ok',
'version': app.config['CURRENT_VERSION']
}), status=200, content_type="application/json")
return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
status=200,
content_type="application/json",
)
@app.route('/threads')
@app.route("/threads")
def threads():
num_threads = threading.active_count()
threads = threading.enumerate()
@@ -278,32 +279,34 @@ def threads():
thread_id = thread.ident
is_alive = thread.is_alive()
thread_list.append({
'name': thread_name,
'id': thread_id,
'is_alive': is_alive
})
thread_list.append(
{
"name": thread_name,
"id": thread_id,
"is_alive": is_alive,
}
)
return {
'pid': os.getpid(),
'thread_num': num_threads,
'threads': thread_list
"pid": os.getpid(),
"thread_num": num_threads,
"threads": thread_list,
}
@app.route('/db-pool-stat')
@app.route("/db-pool-stat")
def pool_stat():
engine = db.engine
return {
'pid': os.getpid(),
'pool_size': engine.pool.size(),
'checked_in_connections': engine.pool.checkedin(),
'checked_out_connections': engine.pool.checkedout(),
'overflow_connections': engine.pool.overflow(),
'connection_timeout': engine.pool.timeout(),
'recycle_time': db.engine.pool._recycle
"pid": os.getpid(),
"pool_size": engine.pool.size(),
"checked_in_connections": engine.pool.checkedin(),
"checked_out_connections": engine.pool.checkedout(),
"overflow_connections": engine.pool.overflow(),
"connection_timeout": engine.pool.timeout(),
"recycle_time": db.engine.pool._recycle,
}
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5001)
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)

View File

@@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService
@click.command('reset-password', help='Reset the account password.')
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
@click.option('--new-password', prompt=True, help='the new password.')
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
@click.option("--new-password", prompt=True, help="the new password.")
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
def reset_password(email, new_password, password_confirm):
"""
Reset password of owner account
Only available in SELF_HOSTED mode
"""
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
return
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(
click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
return
# generate password salt
@@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
@click.command('reset-email', help='Reset the account email.')
@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
@click.option('--new-email', prompt=True, help='the new email.')
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
@click.command("reset-email", help="Reset the account email.")
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
@click.option("--new-email", prompt=True, help="the new email.")
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
def reset_email(email, new_email, email_confirm):
"""
Replace account email
:return:
"""
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
return
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
email_validate(new_email)
except:
click.echo(
click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
return
account.email = new_email
db.session.commit()
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
'After the reset, all LLM credentials will become invalid, '
'requiring re-entry.'
'Only support SELF_HOSTED mode.')
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
' this operation cannot be rolled back!', fg='red'))
@click.command(
"reset-encrypt-key-pair",
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
"After the reset, all LLM credentials will become invalid, "
"requiring re-entry."
"Only support SELF_HOSTED mode.",
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
)
)
def reset_encrypt_key_pair():
"""
Reset the encrypted key pair of workspace for encrypt LLM credentials.
After the reset, all LLM credentials will become invalid, requiring re-entry.
Only support SELF_HOSTED mode.
"""
if dify_config.EDITION != 'SELF_HOSTED':
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
return
tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
click.echo(click.style('Congratulations! '
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
click.echo(
click.style(
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green",
)
)
@click.command('vdb-migrate', help='migrate vector db.')
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str):
if scope in ['knowledge', 'all']:
if scope in ["knowledge", "all"]:
migrate_knowledge_vector_database()
if scope in ['annotation', 'all']:
if scope in ["annotation", "all"]:
migrate_annotation_vector_database()
@@ -146,7 +150,7 @@ def migrate_annotation_vector_database():
"""
Migrate annotation data to the target vector database.
"""
click.echo(click.style('Start migrate annotation data.', fg='green'))
click.echo(click.style("Start migrate annotation data.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@@ -154,98 +158,103 @@ def migrate_annotation_vector_database():
while True:
try:
# get apps info
apps = db.session.query(App).filter(
App.status == 'normal'
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
apps = (
db.session.query(App)
.filter(App.status == "normal")
.order_by(App.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound:
break
page += 1
for app in apps:
total_count = total_count + 1
click.echo(f'Processing the {total_count} app {app.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
click.echo(
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo('Create app annotation index: {}'.format(app.id))
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app.id
).first()
click.echo("Create app annotation index: {}".format(app.id))
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo('App annotation setting is disabled: {}'.format(app.id))
click.echo("App annotation setting is disabled: {}".format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
).first()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
click.echo("App annotation collection binding is not exist: {}".format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
collection_binding_id=dataset_collection_binding.id,
)
documents = []
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={
"annotation_id": annotation.id,
"app_id": app.id,
"doc_id": annotation.id
}
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
)
documents.append(document)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index for app: {app.id}.',
fg='green'))
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index for app {app.id}.',
fg='red'))
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
raise e
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
click.style(
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
raise e
click.echo(f'Successfully migrated app annotation {app.id}.')
click.echo(f"Successfully migrated app annotation {app.id}.")
create_count += 1
except Exception as e:
click.echo(
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.style(
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
)
)
continue
click.echo(
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
fg='green'))
click.style(
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
fg="green",
)
)
def migrate_knowledge_vector_database():
"""
Migrate vector database data to the target vector database.
"""
click.echo(click.style('Start migrate vector db.', fg='green'))
click.echo(click.style("Start migrate vector db.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@@ -253,87 +262,77 @@ def migrate_knowledge_vector_database():
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
datasets = (
db.session.query(Dataset)
.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound:
break
page += 1
for dataset in datasets:
total_count = total_count + 1
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
click.echo(
f"Processing the {total_count} dataset {dataset.id}. "
+ f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
click.echo("Create dataset vdb index: {}".format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type:
if dataset.index_struct_dict["type"] == vector_type:
skipped_count = skipped_count + 1
continue
collection_name = ''
collection_name = ""
if vector_type == VectorType.WEAVIATE:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.WEAVIATE,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
raise ValueError("Dataset Collection Bindings is not exist!")
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.QDRANT,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.MILVUS:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.MILVUS,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.RELYT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'relyt',
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.TENCENT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.TENCENT,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.PGVECTOR,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.OPENSEARCH:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.OPENSEARCH,
"vector_store": {"class_prefix": collection_name}
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB:
@@ -341,16 +340,13 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name}
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ELASTICSEARCH:
dataset_id = dataset.id
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'elasticsearch',
"vector_store": {"class_prefix": index_name}
}
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -361,29 +357,41 @@ def migrate_knowledge_vector_database():
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
fg='green'))
click.style(
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
)
)
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
fg='red'))
click.style(
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
)
)
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
for segment in segments:
document = Document(
@@ -393,7 +401,7 @@ def migrate_knowledge_vector_database():
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
},
)
documents.append(document)
@@ -401,37 +409,43 @@ def migrate_knowledge_vector_database():
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
fg='green'))
click.echo(
click.style(
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
)
except Exception as e:
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
raise e
db.session.add(dataset)
db.session.commit()
click.echo(f'Successfully migrated dataset {dataset.id}.')
click.echo(f"Successfully migrated dataset {dataset.id}.")
create_count += 1
except Exception as e:
db.session.rollback()
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
)
continue
click.echo(
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
fg='green'))
click.style(
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
)
)
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
def convert_to_agent_apps():
"""
Convert Agent Assistant to Agent App.
"""
click.echo(click.style('Start convert to agent apps.', fg='green'))
click.echo(click.style("Start convert to agent apps.", fg="green"))
proceeded_app_ids = []
@@ -466,7 +480,7 @@ def convert_to_agent_apps():
break
for app in apps:
click.echo('Converting app: {}'.format(app.id))
click.echo("Converting app: {}".format(app.id))
try:
app.mode = AppMode.AGENT_CHAT.value
@@ -478,137 +492,139 @@ def convert_to_agent_apps():
)
db.session.commit()
click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
except Exception as e:
click.echo(
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
str(e)), fg='red'))
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
def add_qdrant_doc_id_index(field: str):
click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
vector_type = dify_config.VECTOR_STORE
if vector_type != "qdrant":
click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
return
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings:
click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
return
import qdrant_client
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError('Qdrant url is required.')
raise ValueError("Qdrant url is required.")
qdrant_config = QdrantConfig(
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
)
try:
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
# create payload index
client.create_payload_index(binding.collection_name, field,
field_schema=PayloadSchemaType.KEYWORD)
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
create_count += 1
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
click.echo(
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
)
continue
# Some other error occurred, so re-raise the exception
else:
click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
click.echo(
click.style(
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
)
)
except Exception as e:
click.echo(click.style('Failed to create qdrant client.', fg='red'))
click.echo(click.style("Failed to create qdrant client.", fg="red"))
click.echo(
click.style(f'Congratulations! Create {create_count} collection indexes.',
fg='green'))
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
@click.command('create-tenant', help='Create account and tenant.')
@click.option('--email', prompt=True, help='The email address of the tenant account.')
@click.option('--language', prompt=True, help='Account language, default: en-US.')
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None):
"""
Create tenant account
"""
if not email:
click.echo(click.style('Sorry, email is required.', fg='red'))
click.echo(click.style("Sorry, email is required.", fg="red"))
return
# Create account
email = email.strip()
if '@' not in email:
click.echo(click.style('Sorry, invalid email address.', fg='red'))
if "@" not in email:
click.echo(click.style("Sorry, invalid email address.", fg="red"))
return
account_name = email.split('@')[0]
account_name = email.split("@")[0]
if language not in languages:
language = 'en-US'
language = "en-US"
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(
email=email,
name=account_name,
password=new_password,
language=language
)
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
TenantService.create_owner_tenant_if_not_exist(account)
click.echo(click.style('Congratulations! Account and tenant created.\n'
'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
click.echo(
click.style(
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)
@click.command('upgrade-db', help='upgrade the database')
@click.command("upgrade-db", help="upgrade the database")
def upgrade_db():
click.echo('Preparing database migration...')
lock = redis_client.lock(name='db_upgrade_lock', timeout=60)
click.echo("Preparing database migration...")
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
try:
click.echo(click.style('Start database migration.', fg='green'))
click.echo(click.style("Start database migration.", fg="green"))
# run db migration
import flask_migrate
flask_migrate.upgrade()
click.echo(click.style('Database migration successful!', fg='green'))
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
logging.exception(f'Database migration failed, error: {e}')
logging.exception(f"Database migration failed, error: {e}")
finally:
lock.release()
else:
click.echo('Database migration skipped')
click.echo("Database migration skipped")
@click.command('fix-app-site-missing', help='Fix app related site missing issue.')
@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
def fix_app_site_missing():
"""
Fix app related site missing issue.
"""
click.echo(click.style('Start fix app related site missing issue.', fg='green'))
click.echo(click.style("Start fix app related site missing issue.", fg="green"))
failed_app_ids = []
while True:
@@ -639,15 +655,14 @@ where sites.id is null limit 1000"""
app_was_created.send(app, account=account)
except Exception as e:
failed_app_ids.append(app_id)
click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red'))
logging.exception(f'Fix app related site missing issue failed, error: {e}')
click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
logging.exception(f"Fix app related site missing issue failed, error: {e}")
continue
if not processed_count:
break
click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
def register_commands(app):
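The `upgrade-db` command above wraps `flask_migrate.upgrade()` in a non-blocking Redis lock so that, when several API containers start at once, only one runs the migration and the rest skip it. The same pattern in isolation (assumes the `redis` package and a server on localhost:6379):

import redis

redis_client = redis.Redis(host="localhost", port=6379)

lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
    try:
        print("Start database migration.")  # flask_migrate.upgrade() would run here
    finally:
        lock.release()
else:
    print("Database migration skipped")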

View File

@@ -37,6 +37,7 @@ class DifyConfig(
CODE_MAX_NUMBER: int = 9223372036854775807
CODE_MIN_NUMBER: int = -9223372036854775808
CODE_MAX_DEPTH: int = 5
CODE_MAX_STRING_LENGTH: int = 80000
CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30

View File

@@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings):
default=False,
)
class WorkspaceConfig(BaseSettings):
"""
Workspace configs
@@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings):
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description='The heads of model providers',
default='',
)
POSITION_PROVIDER_INCLUDES: str = Field(
description='The included model providers',
default='',
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description='The excluded model providers',
default='',
)
POSITION_TOOL_PINS: str = Field(
description='The heads of tools',
default='',
)
POSITION_TOOL_INCLUDES: str = Field(
description='The included tools',
default='',
)
POSITION_TOOL_EXCLUDES: str = Field(
description='The excluded tools',
default='',
)
@computed_field
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != '']
@computed_field
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''}
@computed_field
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''}
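Each POSITION_* setting stays a plain comma-separated string in the environment; the computed fields above just split, strip, and drop empty items. A quick illustration of that parsing with a hypothetical value:

POSITION_TOOL_PINS = "google, bing , "

pins = [item.strip() for item in POSITION_TOOL_PINS.split(",") if item.strip() != ""]
print(pins)  # -> ['google', 'bing']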
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -466,6 +524,7 @@ class FeatureConfig(
UpdateConfig,
WorkflowConfig,
WorkspaceConfig,
PositionConfig,
# hosted services config
HostedServiceConfig,

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.7.0',
default='0.7.1',
)
COMMIT_SHA: str = Field(

View File

@@ -1 +1 @@
HIDDEN_VALUE = '[__HIDDEN__]'
HIDDEN_VALUE = "[__HIDDEN__]"

View File

@@ -1,22 +1,22 @@
language_timezone_mapping = {
'en-US': 'America/New_York',
'zh-Hans': 'Asia/Shanghai',
'zh-Hant': 'Asia/Taipei',
'pt-BR': 'America/Sao_Paulo',
'es-ES': 'Europe/Madrid',
'fr-FR': 'Europe/Paris',
'de-DE': 'Europe/Berlin',
'ja-JP': 'Asia/Tokyo',
'ko-KR': 'Asia/Seoul',
'ru-RU': 'Europe/Moscow',
'it-IT': 'Europe/Rome',
'uk-UA': 'Europe/Kyiv',
'vi-VN': 'Asia/Ho_Chi_Minh',
'ro-RO': 'Europe/Bucharest',
'pl-PL': 'Europe/Warsaw',
'hi-IN': 'Asia/Kolkata',
'tr-TR': 'Europe/Istanbul',
'fa-IR': 'Asia/Tehran',
"en-US": "America/New_York",
"zh-Hans": "Asia/Shanghai",
"zh-Hant": "Asia/Taipei",
"pt-BR": "America/Sao_Paulo",
"es-ES": "Europe/Madrid",
"fr-FR": "Europe/Paris",
"de-DE": "Europe/Berlin",
"ja-JP": "Asia/Tokyo",
"ko-KR": "Asia/Seoul",
"ru-RU": "Europe/Moscow",
"it-IT": "Europe/Rome",
"uk-UA": "Europe/Kyiv",
"vi-VN": "Asia/Ho_Chi_Minh",
"ro-RO": "Europe/Bucharest",
"pl-PL": "Europe/Warsaw",
"hi-IN": "Asia/Kolkata",
"tr-TR": "Europe/Istanbul",
"fa-IR": "Asia/Tehran",
}
languages = list(language_timezone_mapping.keys())
@@ -26,6 +26,5 @@ def supported_language(lang):
if lang in languages:
return lang
error = ('{lang} is not a valid language.'
.format(lang=lang))
error = "{lang} is not a valid language.".format(lang=lang)
raise ValueError(error)

View File

@@ -5,82 +5,79 @@ from models.model import AppMode
default_app_templates = {
# workflow default mode
AppMode.WORKFLOW: {
'app': {
'mode': AppMode.WORKFLOW.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.WORKFLOW.value,
"enable_site": True,
"enable_api": True,
}
},
# completion default mode
AppMode.COMPLETION: {
'app': {
'mode': AppMode.COMPLETION.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.COMPLETION.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
"completion_params": {},
},
'user_input_form': json.dumps([
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
]),
'pre_prompt': '{{query}}'
"user_input_form": json.dumps(
[
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": "",
},
},
]
),
"pre_prompt": "{{query}}",
},
},
# chat default mode
AppMode.CHAT: {
'app': {
'mode': AppMode.CHAT.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.CHAT.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
}
}
"completion_params": {},
},
},
},
# advanced-chat default mode
AppMode.ADVANCED_CHAT: {
'app': {
'mode': AppMode.ADVANCED_CHAT.value,
'enable_site': True,
'enable_api': True
}
"app": {
"mode": AppMode.ADVANCED_CHAT.value,
"enable_site": True,
"enable_api": True,
},
},
# agent-chat default mode
AppMode.AGENT_CHAT: {
'app': {
'mode': AppMode.AGENT_CHAT.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.AGENT_CHAT.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
}
}
}
"completion_params": {},
},
},
},
}

View File

@@ -2,6 +2,6 @@ from contextvars import ContextVar
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar('tenant_id')
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
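`ContextVar` gives each request or worker its own binding, so concurrent requests never see each other's tenant. A minimal stdlib usage sketch:

from contextvars import ContextVar

tenant_id: ContextVar[str] = ContextVar("tenant_id")

token = tenant_id.set("tenant-123")   # bind for the current context
print(tenant_id.get())                # -> tenant-123
tenant_id.reset(token)                # restore the previous state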

View File

@@ -61,6 +61,7 @@ class AppListApi(Resource):
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
@@ -94,6 +95,7 @@ class AppImportApi(Resource):
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
@@ -167,6 +169,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
parser.add_argument('max_active_requests', type=int, location='json')
@@ -208,6 +211,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon_type', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()

View File

@@ -33,7 +33,7 @@ class CompletionConversationApi(Resource):
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
def get(self, app_model):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
@@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
def get(self, app_model, conversation_id):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
@@ -119,7 +119,7 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def delete(self, app_model, conversation_id):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
@@ -154,6 +154,8 @@ class ChatConversationApi(Resource):
parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
subquery = (
@@ -225,7 +227,17 @@ class ChatConversationApi(Resource):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.order_by(Conversation.created_at.desc())
match args['sort_by']:
case 'created_at':
query = query.order_by(Conversation.created_at.asc())
case '-created_at':
query = query.order_by(Conversation.created_at.desc())
case 'updated_at':
query = query.order_by(Conversation.updated_at.asc())
case '-updated_at':
query = query.order_by(Conversation.updated_at.desc())
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(
query,
@@ -256,7 +268,7 @@ class ChatConversationDetailApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
def delete(self, app_model, conversation_id):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
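The new `sort_by` argument above is dispatched with structural pattern matching, with unknown values falling back to newest-first. The same dispatch in isolation (strings stand in for the SQLAlchemy order_by clauses):

def order_clause(sort_by: str) -> str:
    match sort_by:
        case "created_at":
            return "created_at asc"
        case "-created_at":
            return "created_at desc"
        case "updated_at":
            return "updated_at asc"
        case "-updated_at":
            return "updated_at desc"
        case _:
            return "created_at desc"

print(order_clause("-updated_at"))  # -> updated_at desc
print(order_clause("bogus"))        # -> created_at desc (default)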

View File

@@ -16,6 +16,7 @@ from models.model import Site
def parse_app_site_args():
parser = reqparse.RequestParser()
parser.add_argument('title', type=str, required=False, location='json')
parser.add_argument('icon_type', type=str, required=False, location='json')
parser.add_argument('icon', type=str, required=False, location='json')
parser.add_argument('icon_background', type=str, required=False, location='json')
parser.add_argument('description', type=str, required=False, location='json')
@@ -53,6 +54,7 @@ class AppSite(Resource):
for attr_name in [
'title',
'icon_type',
'icon',
'icon_background',
'description',

View File

@@ -459,6 +459,7 @@ class ConvertToWorkflowApi(Resource):
if request.data:
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()

View File

@@ -573,13 +573,13 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,

View File

@@ -25,6 +25,8 @@ class ConversationApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
try:
@@ -33,7 +35,8 @@
user=end_user,
last_id=args['last_id'],
limit=args['limit'],
invoke_from=InvokeFrom.SERVICE_API
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args['sort_by']
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@@ -59,13 +59,16 @@ class SegmentApi(DatasetApiResource):
parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
if args['segments'] is not None:
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
else:
return {"error": "Segemtns is required"}, 400
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""

View File

@@ -26,6 +26,8 @@ class ConversationListApi(WebApiResource):
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'],
required=False, default='-updated_at', location='args')
args = parser.parse_args()
pinned = None
@@ -40,6 +42,7 @@
limit=args['limit'],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args['sort_by']
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@@ -6,6 +6,7 @@ from configs import dify_config
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from libs.helper import AppIconUrlField
from models.account import TenantStatus
from models.model import Site
from services.feature_service import FeatureService
@@ -28,8 +29,10 @@ class AppSiteApi(WebApiResource):
'title': fields.String,
'chat_color_theme': fields.String,
'chat_color_theme_inverted': fields.Boolean,
'icon_type': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'icon_url': AppIconUrlField,
'description': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,

View File

@@ -64,15 +64,19 @@ class BaseAgentRunner(AppRunner):
"""
Agent runner
:param tenant_id: tenant id
:param application_generate_entity: application generate entity
:param conversation: conversation
:param app_config: app generate entity
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
:param prompt_messages: prompt messages
:param variables_pool: variables pool
:param db_variables: db variables
:param model_instance: model instance
"""
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
@@ -445,7 +449,7 @@ class BaseAgentRunner(AppRunner):
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
tool_responses = { tool: agent_thought.observation for tool in tools }
tool_responses = dict.fromkeys(tools, agent_thought.observation)
for tool in tools:
# generate a uuid for tool call
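`dict.fromkeys(tools, observation)` is the idiomatic equivalent of the removed comprehension: every tool name maps to the same observation. Note that all keys share the one value object, which is harmless here because the observation is a string:

tools = ["search", "calculator"]
observation = "no output"

assert dict.fromkeys(tools, observation) == {tool: observation for tool in tools}
print(dict.fromkeys(tools, observation))
# -> {'search': 'no output', 'calculator': 'no output'}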

View File

@@ -292,6 +292,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
handle invoke action
:param action: action
:param tool_instances: tool instances
:param message_file_ids: message file ids
:param trace_manager: trace manager
:return: observation, meta
"""
# action is tool call, invoke tool

View File

@@ -1,6 +1,6 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory
@@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
:param config: model config args
"""
external_data_variables = []
variables = []
variable_entities = []
# old external_data_tools
external_data_tools = config.get('external_data_tools', [])
@@ -30,50 +30,41 @@
)
# variables and external_data_tools
for variable in config.get('user_input_form', []):
typ = list(variable.keys())[0]
if typ == 'external_data_tool':
val = variable[typ]
if 'config' not in val:
for variables in config.get('user_input_form', []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if 'config' not in variable:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=val['variable'],
type=val['type'],
config=val['config']
variable=variable['variable'],
type=variable['type'],
config=variable['config']
)
)
elif typ in [
VariableEntity.Type.TEXT_INPUT.value,
VariableEntity.Type.PARAGRAPH.value,
VariableEntity.Type.NUMBER.value,
elif variable_type in [
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]:
variables.append(
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=VariableEntity.Type.value_of(typ),
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
max_length=variable[typ].get('max_length'),
default=variable[typ].get('default'),
)
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
type=variable_type,
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
)
)
return variables, external_data_variables
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
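A sketch of the user_input_form shape the extraction loop above consumes: each entry is a single-key dict whose key is the variable type string (the entries here are illustrative, not from the diff).

user_input_form = [
    {'text-input': {'variable': 'name', 'label': 'Name', 'required': True}},
    {'select': {'variable': 'color', 'label': 'Color', 'options': ['red', 'blue']}},
]
for entry in user_input_form:
    variable_type = list(entry.keys())[0]   # the single key is the type
    config = entry[variable_type]
    print(variable_type, config['variable'])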

View File

@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'
@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')
variable: str
label: str
description: Optional[str] = None
type: Type
type: VariableEntityType
required: bool = False
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None
@property
def name(self) -> str:
return self.variable
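Because VariableEntityType subclasses str, its members compare equal to their raw values, which lets callers match enum members directly against the type strings in app config. A tiny sketch:

from enum import Enum

class VariableEntityType(str, Enum):
    TEXT_INPUT = "text-input"

assert VariableEntityType.TEXT_INPUT == "text-input"
assert "text-input" == VariableEntityType.TEXT_INPUT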
class ExternalDataVariableEntity(BaseModel):
"""

View File

@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
):
"""
Generate App response.
@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# parse files
files = args['files'] if args.get('files') else []
@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
"""
Generate App response.
@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
is_first_conversation = False
if not conversation:
is_first_conversation = True
@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
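The rewritten check above supplies a default, so .lower() is never called on None when DEBUG is unset. A quick sketch of the behavior:

import os

os.environ.pop("DEBUG", None)                       # DEBUG unset: default "false" applies
assert os.environ.get("DEBUG", "false").lower() != 'true'
os.environ["DEBUG"] = "True"                        # match is case-insensitive
assert os.environ.get("DEBUG", "false").lower() == 'true'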

View File

@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id,
}
self._task_state = AdvancedChatTaskState(
@ -249,8 +249,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
for message in self._queue_manager.listen():
if (message.event
and hasattr(message.event, 'metadata')
and message.event.metadata
and getattr(message.event, 'metadata', None)
and message.event.metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
class BaseAppGenerator:
@ -9,29 +9,29 @@ class BaseAppGenerator:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name)
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form')
raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
if (
var.type
in (
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.SELECT,
VariableEntity.Type.PARAGRAPH,
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
)
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
@ -39,14 +39,14 @@ class BaseAppGenerator:
else:
return int(user_input_value)
except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
return user_input_value
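A standalone sketch of the number branch above (parse_number is a hypothetical helper, not part of this diff): values containing a decimal point become floats, others become ints, and unparsable input raises the validation error.

def parse_number(name: str, value: str):
    # mirrors the int/float coercion in _validate_input
    try:
        return float(value) if '.' in value else int(value)
    except ValueError:
        raise ValueError(f"{name} in input form must be a valid number")

assert parse_number("age", "42") == 42
assert parse_number("score", "3.14") == 3.14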

View File

@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileVar],
files: list["FileVar"],
query: Optional[str] = None) -> int:
"""
Get pre calculate rest tokens
@ -126,7 +128,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileVar],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
@ -254,6 +256,7 @@ class AppRunner:
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
:return:
"""
if not stream:
@ -276,6 +279,7 @@ class AppRunner:
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:return:
"""
queue_manager.publish(
@ -291,6 +295,7 @@ class AppRunner:
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:return:
"""
model = None

View File

@ -1,6 +1,7 @@
import json
import logging
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union
from sqlalchemy import and_
@ -36,17 +37,17 @@ logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response(
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
@ -138,6 +139,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
"""
Initialize generate records
:param application_generate_entity: application generate entity
:param conversation: conversation
:return:
"""
app_config = application_generate_entity.app_config
@ -192,6 +194,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
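The else branch above refreshes updated_at for existing conversations; because the column stores naive UTC, the timezone-aware timestamp is stripped. A sketch of the idiom:

from datetime import datetime, timezone

naive_utc = datetime.now(timezone.utc).replace(tzinfo=None)  # aware now, stored as naive UTC
assert naive_utc.tzinfo is None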
message = Message(
app_id=app_config.app_id,

View File

@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@ -67,8 +67,8 @@ class WorkflowAppRunner:
# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,

View File

@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow
self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id
}
self._task_state = WorkflowTaskState(

View File

@ -99,7 +99,13 @@ class ObjectSegment(Segment):
class ArraySegment(Segment):
@property
def markdown(self) -> str:
return '\n'.join(['- ' + item.markdown for item in self.value])
items = []
for item in self.value:
if hasattr(item, 'to_markdown'):
items.append(item.to_markdown())
else:
items.append(str(item))
return '\n'.join(items)
class ArrayAnySegment(ArraySegment):
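The new markdown property falls back to str() for items without a to_markdown method. A hedged sketch of that behavior (Item and render are illustrative, not from the diff):

class Item:
    def to_markdown(self) -> str:
        return "- rendered"

def render(values) -> str:
    # same hasattr fallback as ArraySegment.markdown
    items = [item.to_markdown() if hasattr(item, 'to_markdown') else str(item) for item in values]
    return '\n'.join(items)

assert render([Item(), 42]) == "- rendered\n42"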

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]

View File

@ -99,7 +99,7 @@ class MessageFileParser:
# return all file objs
return new_files
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
"""
transform message files
@ -144,7 +144,7 @@ class MessageFileParser:
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
"""
transform file to file obj

View File

@ -3,6 +3,7 @@ from collections import OrderedDict
from collections.abc import Callable
from typing import Any
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file
@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
return {name: index for index, name in enumerate(positions)}
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for tools from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
)
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping for providers from name to index from a YAML file.
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
position_map = get_position_map(folder_path, file_name=file_name)
return pin_position_map(
position_map,
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
)
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
"""
Pin the items in the pin list to the beginning of the position map.
Overall logic: exclude > include > pin
:param original_position_map: the position map to be sorted and filtered
:param pin_list: the list of pins to be put at the beginning
:return: the sorted position map
"""
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
# Add pins to position map
position_map = {name: idx for idx, name in enumerate(pin_list)}
# Add remaining positions to position map
start_idx = len(position_map)
for name in positions:
if name not in position_map:
position_map[name] = start_idx
start_idx += 1
return position_map
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: Any,
name_func: Callable[[Any], str],
) -> bool:
"""
Check if the object should be filtered out.
Overall logic: exclude > include > pin
:param include_set: the set of names to be included
:param exclude_set: the set of names to be excluded
:param name_func: the function to get the name of the object
:param data: the data to be filtered
:return: True if the object should be filtered out, False otherwise
"""
if not data:
return False
if not include_set and not exclude_set:
return False
name = name_func(data)
if name in exclude_set: # exclude_set is prioritized
return True
if include_set and name not in include_set: # filter out only if include_set is not empty
return True
return False
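A quick sketch of the precedence in is_filtered as defined above: exclusion wins over inclusion, and empty sets filter nothing.

name_of = lambda x: x
assert is_filtered(include_set={'a'}, exclude_set={'a'}, data='a', name_func=name_of)      # exclude wins
assert is_filtered(include_set={'a'}, exclude_set=set(), data='b', name_func=name_of)      # not in include set
assert not is_filtered(include_set=set(), exclude_set=set(), data='a', name_func=name_of)  # no filters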
def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],

View File

@ -700,6 +700,7 @@ class IndexingRunner:
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.error: None,
}
)

View File

@ -271,9 +271,8 @@ class ModelInstance:
:param content_text: text content to be converted to speech
:param tenant_id: user tenant id
:param user: unique user id
:param voice: model timbre
:param streaming: output is streaming
:param user: unique user id
:return: generated audio for the given text
"""
if not isinstance(self.model_type_instance, TTSModel):
@ -369,6 +368,15 @@ class ModelManager:
return ModelInstance(provider_model_bundle, model)
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
"""
Return first provider and the first model in the provider
:param tenant_id: tenant id
:param model_type: model type
:return: provider name, model name
"""
return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
"""
Get default model instance
@ -401,6 +409,10 @@ class LBModelManager:
managed_credentials: Optional[dict] = None) -> None:
"""
Load balancing model manager
:param tenant_id: tenant_id
:param provider: provider
:param model_type: model_type
:param model: model name
:param load_balancing_configs: all load balancing configurations
:param managed_credentials: credentials if load balancing configuration name is __inherit__
"""
@ -499,7 +511,6 @@ class LBModelManager:
config.id
)
res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
return res

View File

@ -1,4 +1,3 @@
from core.model_runtime.entities.model_entities import DefaultParameterName
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
@ -94,5 +93,16 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
'required': False,
'options': ['JSON', 'XML'],
}
},
DefaultParameterName.JSON_SCHEMA: {
'label': {
'en_US': 'JSON Schema',
},
'type': 'text',
'help': {
'en_US': 'Setting a response JSON schema will ensure the LLM adheres to it.',
'zh_Hans': '设置返回的json schema，llm将按照它返回',
},
'required': False,
},
}

View File

@ -95,6 +95,7 @@ class DefaultParameterName(Enum):
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
JSON_SCHEMA = "json_schema"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
@ -118,6 +119,7 @@ class ParameterType(Enum):
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
TEXT = "text"
class ModelPropertyKey(Enum):

View File

@ -151,9 +151,9 @@ class AIModel(ABC):
os.path.join(provider_model_type_path, model_schema_yaml)
for model_schema_yaml in os.listdir(provider_model_type_path)
if not model_schema_yaml.startswith('__')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
and not model_schema_yaml.startswith('_')
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
and model_schema_yaml.endswith('.yaml')
]
# get _position.yaml file path

View File

@ -792,6 +792,13 @@ if you are not sure about the structure.
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be string.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
elif parameter_rule.type == ParameterType.TEXT:
if not isinstance(parameter_value, str):
raise ValueError(f"Model Parameter {parameter_name} should be text.")
# validate options
if parameter_rule.options and parameter_value not in parameter_rule.options:
raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")

View File

@ -70,7 +70,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
# doc: https://platform.openai.com/docs/guides/text-to-speech
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
# max font is 4096,there is 3500 limit for each request
# max length is 4096 characters, there is 3500 limit for each request
max_length = 3500
if len(content_text) > max_length:
sentences = self._split_text_into_sentences(content_text, max_length=max_length)

View File

@ -6,7 +6,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@ -234,7 +234,7 @@ class ModelProviderFactory:
]
# get _position.yaml file path
position_map = get_position_map(model_providers_path)
position_map = get_provider_position_map(model_providers_path)
# traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []

View File

@ -84,7 +84,8 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)
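The guard above applies the default only when no endpoint is configured, so a user-supplied API base is preserved. A minimal sketch (the example.com value is hypothetical):

credentials = {'endpoint_url': 'https://example.com/v1'}
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
    credentials['endpoint_url'] = 'https://api.moonshot.cn/v1'
assert credentials['endpoint_url'] == 'https://example.com/v1'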

View File

@ -31,6 +31,14 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: Base URL, 如https://api.moonshot.cn/v1
en_US: Base URL, e.g. https://api.moonshot.cn/v1
model_credential_schema:
model:
label:

View File

@ -2,6 +2,7 @@
- gpt-4o
- gpt-4o-2024-05-13
- gpt-4o-2024-08-06
- chatgpt-4o-latest
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- gpt-4-turbo

View File

@ -0,0 +1,44 @@
model: chatgpt-4o-latest
label:
zh_Hans: chatgpt-4o-latest
en_US: chatgpt-4o-latest
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '2.50'
output: '10.00'

View File

@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '0.15'
output: '0.60'

View File

@ -1,3 +1,4 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
@ -544,13 +545,18 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_object":
response_format = {"type": "json_object"}
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not currect json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
response_format = {"type": "text"}
model_parameters["response_format"] = response_format
model_parameters["response_format"] = {"type": response_format}
extra_model_kwargs = {}
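A hedged sketch of the json_schema path above: the json_schema parameter arrives as a JSON string, is parsed, and is re-wrapped into the response_format structure (the sample schema is illustrative only).

import json

model_parameters = {
    "response_format": "json_schema",
    "json_schema": '{"name": "answer", "schema": {"type": "object"}}',
}
schema = json.loads(model_parameters.pop("json_schema"))
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
assert model_parameters["response_format"]["json_schema"]["name"] == "answer"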
@ -922,11 +928,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model.startswith('ft:'):
model = model.split(':')[1]
# Currently, we can use gpt-4o's tokenizer to count chatgpt-4o-latest's tokens.
if model == "chatgpt-4o-latest":
model = "gpt-4o"
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
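A sketch of the fallback above, assuming a recent tiktoken release: chatgpt-4o-latest has no registered encoding of its own, so tokens are counted with gpt-4o's tokenizer.

import tiktoken

model = "chatgpt-4o-latest"
if model == "chatgpt-4o-latest":
    model = "gpt-4o"  # reuse gpt-4o's tokenizer for counting
encoding = tiktoken.encoding_for_model(model)
print(len(encoding.encode("hello world")))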
@ -946,7 +955,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"See https://platform.openai.com/docs/advanced-usage/managing-tokens for "
"information on how messages are converted to tokens."
)
num_tokens = 0

View File

@ -0,0 +1,44 @@
model: gpt-4o-2024-08-06
label:
zh_Hans: gpt-4o-2024-08-06
en_US: gpt-4o-2024-08-06
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,61 @@
model: Llama3-Chinese_v2
label:
en_US: Llama3-Chinese_v2
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Meta-Llama-3-70B-Instruct-GPTQ-Int4
label:
en_US: Meta-Llama-3-70B-Instruct-GPTQ-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 1024
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Meta-Llama-3-8B-Instruct
label:
en_US: Meta-Llama-3-8B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
label:
en_US: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 410960
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Meta-Llama-3.1-8B-Instruct
label:
en_US: Meta-Llama-3.1-8B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.1
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -55,7 +55,8 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB
deprecated: true

View File

@ -55,7 +55,8 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB
deprecated: true

View File

@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
context_size: 2048
parameter_rules:
- name: temperature
use_template: temperature
@ -55,7 +55,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: completion
context_size: 8192
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
@ -55,7 +55,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -8,12 +8,12 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 8192
context_size: 2048
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
default: 0.7
min: 0.0
max: 2.0
help:
@ -57,7 +57,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Qwen2-72B-Instruct
label:
en_US: Qwen2-72B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -8,7 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: completion
context_size: 8192
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
@ -57,7 +57,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,6 +1,15 @@
- Meta-Llama-3.1-405B-Instruct-AWQ-INT4
- Meta-Llama-3.1-8B-Instruct
- Meta-Llama-3-70B-Instruct-GPTQ-Int4
- Meta-Llama-3-8B-Instruct
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-72B-Instruct
- Qwen2-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- Qwen-14B-Chat-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen-14B-Chat-Int4
- Qwen1.5-110B-Chat-GPTQ-Int4
- deepseek-v2-chat
- deepseek-v2-lite-chat
- Llama3-Chinese_v2
- chatglm3-6b

View File

@ -0,0 +1,61 @@
model: chatglm3-6b
label:
en_US: chatglm3-6b
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: deepseek-v2-chat
label:
en_US: deepseek-v2-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
en_US: Specifies the maximum number of tokens the model may generate. It defines an upper limit on generation; the model will not necessarily produce this many tokens every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值。例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold for nucleus sampling during generation. For example, at a value of 0.8, only the smallest set of most likely tokens whose cumulative probability is at least 0.8 is kept as the candidate set. The value range is (0, 1.0); larger values produce more random output, and smaller values produce more deterministic output.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
en_US: The size of the candidate set sampled during generation. For example, at a value of 50, only the 50 highest-scoring tokens of a single generation step form the candidate set for random sampling. Larger values produce more random output; smaller values produce more deterministic output.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Controls the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: deepseek-v2-lite-chat
label:
en_US: deepseek-v2-lite-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 2048
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls how strongly the probability distribution over candidate tokens is smoothed during generation. A higher temperature flattens the distribution, allowing more low-probability tokens to be selected and making the output more diverse; a lower temperature sharpens the distribution, making high-probability tokens more likely to be selected and the output more deterministic.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
en_US: Specifies the maximum number of tokens the model may generate. It defines an upper limit on generation; the model will not necessarily produce this many tokens every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值。例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold for nucleus sampling during generation. For example, at a value of 0.8, only the smallest set of most likely tokens whose cumulative probability is at least 0.8 is kept as the candidate set. The value range is (0, 1.0); larger values produce more random output, and smaller values produce more deterministic output.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
en_US: The size of the candidate set sampled during generation. For example, at a value of 50, only the 50 highest-scoring tokens of a single generation step form the candidate set for random sampling. Larger values produce more random output; smaller values produce more deterministic output.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Controls the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,4 @@
model: BAAI/bge-large-en-v1.5
model_type: text-embedding
model_properties:
context_size: 32768

View File

@ -0,0 +1,4 @@
model: BAAI/bge-large-zh-v1.5
model_type: text-embedding
model_properties:
context_size: 32768

View File

@ -0,0 +1,4 @@
model: netease-youdao/bce-reranker-base_v1
model_type: rerank
model_properties:
context_size: 512

View File

@ -0,0 +1,4 @@
model: BAAI/bge-reranker-v2-m3
model_type: rerank
model_properties:
context_size: 8192

View File

@ -0,0 +1,87 @@
from typing import Optional
import httpx
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class SiliconflowRerankModel(RerankModel):
def _invoke(self, model: str, credentials: dict, query: str, docs: list[str],
score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
if len(docs) == 0:
return RerankResult(model=model, docs=[])
base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1')
if base_url.endswith('/'):
base_url = base_url[:-1]
try:
response = httpx.post(
base_url + '/rerank',
json={
"model": model,
"query": query,
"documents": docs,
"top_n": top_n,
"return_documents": True
},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results['results']:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}
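A minimal usage sketch for the rerank model above (the model name, API key, and documents are placeholders, not values from this change):

model = SiliconflowRerankModel()
result = model._invoke(
    model='BAAI/bge-reranker-v2-m3',
    credentials={'api_key': 'sk-placeholder'},  # placeholder credentials
    query='What is the capital of the United States?',
    docs=[
        'Carson City is the capital city of the American state of Nevada.',
        'Washington, D.C. is the capital of the United States.',
    ],
    score_threshold=0.5,
    top_n=1,
)
# result.docs holds RerankDocument entries whose relevance_score passed the threshold,
# in the order returned by the API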

View File

@ -12,10 +12,11 @@ help:
en_US: Get your API Key from SiliconFlow
zh_Hans: 从 SiliconFlow 获取 API Key
url:
en_US: https://cloud.siliconflow.cn/keys
en_US: https://cloud.siliconflow.cn/account/ak
supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
configurate_methods:
- predefined-model

View File

@ -159,6 +159,8 @@ You should also complete the text started with ``` but not tell ``` directly.
"""
if model in ['qwen-turbo-chat', 'qwen-plus-chat']:
model = model.replace('-chat', '')
if model == 'farui-plus':
model = 'qwen-farui-plus'
if model in self.tokenizers:
tokenizer = self.tokenizers[model]

View File

@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.llm.models import (
get_model_config,
get_v2_req_params,
)
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
logger = logging.getLogger(__name__)
@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
-> LLMResult | Generator:
client = MaaSClient.from_credential(credentials)
req_params = ModelConfigs.get(
credentials['base_model_name'], {}).get('req_params', {}).copy()
if credentials.get('context_size'):
req_params['max_prompt_tokens'] = credentials.get('context_size')
if credentials.get('max_tokens'):
req_params['max_new_tokens'] = credentials.get('max_tokens')
if model_parameters.get('max_tokens'):
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('top_k'):
req_params['top_k'] = model_parameters.get('top_k')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
req_params = get_v2_req_params(credentials, model_parameters, stop)
extra_model_kwargs = {}
if tools:
extra_model_kwargs['tools'] = [
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
]
resp = MaaSClient.wrap_exception(
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
if not stream:
@ -197,10 +175,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
"""
used to define customizable model schema
"""
max_tokens = ModelConfigs.get(
credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
if credentials.get('max_tokens'):
max_tokens = int(credentials.get('max_tokens'))
model_config = get_model_config(credentials)
rules = [
ParameterRule(
name='temperature',
@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
name='presence_penalty',
type=ParameterType.FLOAT,
use_template='presence_penalty',
label={
'en_US': 'Presence Penalty',
'zh_Hans': '存在惩罚',
},
label=I18nObject(
en_US='Presence Penalty',
zh_Hans='存在惩罚',
),
min=-2.0,
max=2.0,
),
@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
name='frequency_penalty',
type=ParameterType.FLOAT,
use_template='frequency_penalty',
label={
'en_US': 'Frequency Penalty',
'zh_Hans': '频率惩罚',
},
label=I18nObject(
en_US='Frequency Penalty',
zh_Hans='频率惩罚',
),
min=-2.0,
max=2.0,
),
@ -257,7 +233,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=max_tokens,
max=model_config.properties.max_tokens,
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
@ -266,16 +242,9 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
),
]
model_properties = ModelConfigs.get(
credentials['base_model_name'], {}).get('model_properties', {}).copy()
if credentials.get('mode'):
model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
model_features = ModelConfigs.get(
credentials['base_model_name'], {}).get('features', [])
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
entity = AIModelEntity(
model=model,
@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model_type=ModelType.LLM,
model_properties=model_properties,
parameter_rules=rules,
features=model_features,
features=model_config.features,
)
return entity

View File

@ -1,181 +1,123 @@
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature
ModelConfigs = {
'Doubao-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-pro-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 32768,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-pro-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Doubao-lite-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 131072,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [
ModelFeature.TOOL_CALL
],
},
'Skylark2-pro-4k': {
'req_params': {
'max_prompt_tokens': 4096,
'max_new_tokens': 4000,
},
'model_properties': {
'context_size': 4096,
'mode': 'chat',
},
'features': [],
},
'Llama3-8B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Llama3-70B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 8192,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-8k': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-32k': {
'req_params': {
'max_prompt_tokens': 32768,
'max_new_tokens': 16384,
},
'model_properties': {
'context_size': 32768,
'mode': 'chat',
},
'features': [],
},
'Moonshot-v1-128k': {
'req_params': {
'max_prompt_tokens': 131072,
'max_new_tokens': 65536,
},
'model_properties': {
'context_size': 131072,
'mode': 'chat',
},
'features': [],
},
'GLM3-130B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'GLM3-130B-Fin': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 4096,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
},
'Mistral-7B': {
'req_params': {
'max_prompt_tokens': 8192,
'max_new_tokens': 2048,
},
'model_properties': {
'context_size': 8192,
'mode': 'chat',
},
'features': [],
}
class ModelProperties(BaseModel):
context_size: int
max_tokens: int
mode: LLMMode
class ModelConfig(BaseModel):
properties: ModelProperties
features: list[ModelFeature]
configs: dict[str, ModelConfig] = {
'Doubao-pro-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-pro-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Doubao-lite-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL]
),
'Skylark2-pro-4k': ModelConfig(
properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
features=[]
),
'Llama3-8B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[]
),
'Llama3-70B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-8k': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-32k': ModelConfig(
properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
features=[]
),
'Moonshot-v1-128k': ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
features=[]
),
'GLM3-130B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'GLM3-130B-Fin': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
features=[]
),
'Mistral-7B': ModelConfig(
properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
features=[]
)
}
def get_model_config(credentials: dict) -> ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = configs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_tokens=int(credentials.get('max_tokens', 0)),
mode=LLMMode.value_of(credentials.get('mode', 'chat')),
),
features=[]
)
return model_configs
def get_v2_req_params(credentials: dict, model_parameters: dict,
stop: list[str] | None = None):
req_params = {}
# predefined properties
model_configs = get_model_config(credentials)
if model_configs:
req_params['max_prompt_tokens'] = model_configs.properties.context_size
req_params['max_new_tokens'] = model_configs.properties.max_tokens
# model parameters
if model_parameters.get('max_tokens'):
req_params['max_new_tokens'] = model_parameters.get('max_tokens')
if model_parameters.get('temperature'):
req_params['temperature'] = model_parameters.get('temperature')
if model_parameters.get('top_p'):
req_params['top_p'] = model_parameters.get('top_p')
if model_parameters.get('top_k'):
req_params['top_k'] = model_parameters.get('top_k')
if model_parameters.get('presence_penalty'):
req_params['presence_penalty'] = model_parameters.get(
'presence_penalty')
if model_parameters.get('frequency_penalty'):
req_params['frequency_penalty'] = model_parameters.get(
'frequency_penalty')
if stop:
req_params['stop'] = stop
return req_params
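For illustration, given a known base model plus a couple of overrides, get_v2_req_params above yields a request-parameter dict like the following (values are hypothetical):

credentials = {'base_model_name': 'Doubao-pro-32k'}
params = get_v2_req_params(credentials, {'temperature': 0.7, 'max_tokens': 1024}, stop=['</s>'])
# -> {'max_prompt_tokens': 32768, 'max_new_tokens': 1024,
#     'temperature': 0.7, 'stop': ['</s>']}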

View File

@ -1,9 +1,27 @@
from pydantic import BaseModel
class ModelProperties(BaseModel):
context_size: int
max_chunks: int
class ModelConfig(BaseModel):
properties: ModelProperties
ModelConfigs = {
'Doubao-embedding': {
'req_params': {},
'model_properties': {
'context_size': 4096,
'max_chunks': 1,
}
},
'Doubao-embedding': ModelConfig(
properties=ModelProperties(context_size=4096, max_chunks=1)
),
}
def get_model_config(credentials: dict) -> ModelConfig:
base_model = credentials.get('base_model_name', '')
model_configs = ModelConfigs.get(base_model)
if not model_configs:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get('context_size', 0)),
max_chunks=int(credentials.get('max_chunks', 0)),
)
)
return model_configs

View File

@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
RateLimitErrors,
ServerUnavailableErrors,
)
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
"""
generate custom model entities from credentials
"""
model_properties = ModelConfigs.get(
credentials['base_model_name'], {}).get('model_properties', {}).copy()
if credentials.get('context_size'):
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
credentials.get('context_size', 4096))
if credentials.get('max_chunks'):
model_properties[ModelPropertyKey.MAX_CHUNKS] = int(
credentials.get('max_chunks', 4096))
model_config = get_model_config(credentials)
model_properties = {}
model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),

View File

@ -0,0 +1,198 @@
from datetime import datetime, timedelta
from threading import Lock
from requests import post
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()
class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)
@staticmethod
def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp['access_token']
@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
LLMs from Baidu require an access token to invoke the API.
We only hold an api_key and secret_key, and an access token is valid for 30 days,
so we cache each token for 3 days (to keep the cache from growing without bound).
A background ticker to refresh tokens might be more efficient, but it would add
complexity, so we simply refresh tokens when get_access_token is called.
"""
# look up the cache and remove expired access tokens
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)
if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release the lock early to improve concurrency
# note: _get_access_token raises on failure, so releasing the lock here also avoids a deadlock
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token
class _CommonWenxin:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en',
'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh',
'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k',
}
function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]
api_key: str = ''
secret_key: str = ''
def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
credentials_kwargs = {
"api_key": credentials['api_key'],
"secret_key": credentials['secret_key']
}
return credentials_kwargs
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}
if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')
def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token
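A brief sketch of how the cached token above would be consumed (the api_key and secret_key are placeholders):

token = BaiduAccessToken.get_access_token('my-api-key', 'my-secret-key')
url = f'{_CommonWenxin.api_bases["ernie-3.5-8k"]}?access_token={token.access_token}'
# subsequent calls within three days reuse the cached token for the same api_key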

View File

@ -1,102 +1,17 @@
from collections.abc import Generator
from datetime import datetime, timedelta
from enum import Enum
from json import dumps, loads
from threading import Lock
from typing import Any, Union
from requests import Response, post
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
# map api_key to access_token
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()
class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)
def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp['access_token']
@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
LLM from Baidu requires access token to invoke the API.
however, we have api_key and secret_key, and access token is valid for 30 days.
so we can cache the access token for 3 days. (avoid memory leak)
it may be more efficient to use a ticker to refresh access token, but it will cause
more complexity, so we just refresh access tokens when get_access_token is called.
"""
# loop up cache, remove expired access token
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)
if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release it to enhance performance
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token
class ErnieMessage:
class Role(Enum):
@ -120,51 +35,7 @@ class ErnieMessage:
self.content = content
self.role = role
class ErnieBotModel:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
}
function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]
api_key: str = ''
secret_key: str = ''
def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
class ErnieBotModel(_CommonWenxin):
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
@ -199,51 +70,6 @@ class ErnieBotModel:
return self._handle_chat_stream_generate_response(resp)
return self._handle_chat_generate_response(resp)
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}
if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')
def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
return [ErnieMessage(message.content, message.role) for message in messages]

View File

@ -1,17 +0,0 @@
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass

View File

@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
BadRequestError,
InsufficientAccountBalance,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
api_key = credentials['api_key']
secret_key = credentials['secret_key']
try:
BaiduAccessToken._get_access_token(api_key, secret_key)
BaiduAccessToken.get_access_token(api_key, secret_key)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
}
return invoke_error_mapping()

View File

@ -0,0 +1,9 @@
model: bge-large-en
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: bge-large-zh
model_type: text-embedding
model_properties:
context_size: 512
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: embedding-v1
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,9 @@
model: tao-8k
model_type: text-embedding
model_properties:
context_size: 8192
max_chunks: 1
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,184 @@
import time
from abc import abstractmethod
from collections.abc import Mapping
from json import dumps
from typing import Any, Optional
import numpy as np
from requests import Response, post
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
invoke_error_mapping,
)
class TextEmbedding:
@abstractmethod
def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
raise NotImplementedError
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
access_token = self._get_access_token()
url = f'{self.api_bases[model]}?access_token={access_token}'
body = self._build_embed_request_body(model, texts, user)
headers = {
'Content-Type': 'application/json',
}
resp = post(url, data=dumps(body), headers=headers)
if resp.status_code != 200:
raise InternalServerError(f'Failed to invoke Wenxin embedding model: {resp.text}')
return self._handle_embed_response(model, resp)
def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
if len(texts) == 0:
raise BadRequestError('The number of texts should not be zero.')
body = {
'input': texts,
'user_id': user,
}
return body
def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
data = response.json()
if 'error_code' in data:
code = data['error_code']
msg = data['error_msg']
# raise error
self._handle_error(code, msg)
embeddings = [v['embedding'] for v in data['data']]
_usage = data['usage']
tokens = _usage['prompt_tokens']
total_tokens = _usage['total_tokens']
return embeddings, tokens, total_tokens
class WenxinTextEmbeddingModel(TextEmbeddingModel):
def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
return WenxinTextEmbedding(api_key, secret_key)
def _invoke(self, model: str, credentials: dict, texts: list[str],
user: Optional[str] = None) -> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
api_key = credentials['api_key']
secret_key = credentials['secret_key']
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
user = user if user else 'ErnieBotDefault'
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
used_total_tokens = 0
for i, text in enumerate(texts):
# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)
if num_tokens >= context_size:
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0:cutoff])
else:
inputs.append(text)
indices += [i]
batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
for i in _iter:
embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
model,
inputs[i: i + max_chunks],
user)
used_tokens += _used_tokens
used_total_tokens += _total_used_tokens
batched_embeddings += embeddings_batch
usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
return TextEmbeddingResult(
model=model,
embeddings=batched_embeddings,
usage=usage,
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
if len(texts) == 0:
return 0
total_num_tokens = 0
for text in texts:
total_num_tokens += self._get_num_tokens_by_gpt2(text)
return total_num_tokens
def validate_credentials(self, model: str, credentials: Mapping) -> None:
api_key = credentials['api_key']
secret_key = credentials['secret_key']
try:
BaiduAccessToken.get_access_token(api_key, secret_key)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return invoke_error_mapping()
def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=total_tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
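A minimal usage sketch for the embedding client above; longer inputs are truncated to the context size and requests are sent max_chunks texts at a time (the credentials are placeholders):

embedding = WenxinTextEmbedding('my-api-key', 'my-secret-key')  # placeholder credentials
vectors, prompt_tokens, total_tokens = embedding.embed_documents(
    model='embedding-v1',
    texts=['hello world', 'dify'],
    user='ErnieBotDefault',
)
# vectors holds one embedding per input text; the token counts feed usage billing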

View File

@ -17,6 +17,7 @@ help:
en_US: https://cloud.baidu.com/wenxin.html
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:

View File

@ -0,0 +1,57 @@
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
}
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass
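The mapping above is how provider-specific exceptions get translated into unified invoke errors; a minimal sketch of that translation, using a hypothetical helper not present in this change:

def to_invoke_error(exc: Exception) -> InvokeError:
    # walk the mapping and wrap the first matching provider error
    for invoke_error, provider_errors in invoke_error_mapping().items():
        if any(isinstance(exc, err) for err in provider_errors):
            return invoke_error(str(exc))
    return InvokeError(str(exc))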

View File

@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
tools=tools, stop=stop, stream=stream, user=user,
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
)
)
@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
)
if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability:
@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
else:
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
)
if 'chat' in extra_args.model_ability:
@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
xinference_client = Client(
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)
xinference_model = xinference_client.get_model(credentials['model_uid'])

View File

@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel):
# initialize client
client = Client(
base_url=credentials['server_url']
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])

View File

@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
# initialize client
client = Client(
base_url=credentials['server_url']
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])

View File

@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
api_key = credentials.get('api_key')
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=server_url,
model_uid=model_uid,
api_key=api_key,
)
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
if server_url.endswith('/'):
server_url = server_url[:-1]
client = Client(base_url=server_url)
client = Client(
base_url=server_url,
api_key=api_key,
)
try:
handle = client.get_model(model_uid=model_uid)

View File

@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel):
extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
)
if 'text-to-audio' not in extra_param.model_ability:
@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel):
credentials['server_url'] = credentials['server_url'][:-1]
try:
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
api_key = credentials.get('api_key')
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
handle = RESTfulAudioModelHandle(
credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
)
model_support_voice = [x.get("value") for x in
self.get_tts_model_voices(model=model, credentials=credentials)]
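The Xinference changes above all thread an optional api_key through the client; a condensed sketch of the pattern (the server URL, model UID, and import path are assumptions, not taken from this change):

from xinference_client.client.restful.restful_client import Client

credentials = {'server_url': 'http://127.0.0.1:9997', 'model_uid': 'my-model-uid', 'api_key': None}
client = Client(
    base_url=credentials['server_url'],
    api_key=credentials.get('api_key'),  # None keeps unauthenticated servers working
)
model = client.get_model(credentials['model_uid'])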

Some files were not shown because too many files have changed in this diff.