mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 01:25:53 +08:00
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts: # api/core/app/apps/advanced_chat/app_generator.py # api/core/app/apps/advanced_chat/app_runner.py # api/core/app/apps/advanced_chat/generate_task_pipeline.py # api/core/app/apps/base_app_runner.py # api/core/app/apps/workflow/app_runner.py # api/core/app/apps/workflow/generate_task_pipeline.py # api/core/app/task_pipeline/workflow_cycle_state_manager.py # api/core/workflow/entities/node_entities.py # api/core/workflow/nodes/llm/llm_node.py # api/core/workflow/workflow_engine_manager.py # api/tests/integration_tests/workflow/nodes/test_llm.py # api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py # api/tests/unit_tests/core/workflow/nodes/test_answer.py # api/tests/unit_tests/core/workflow/nodes/test_if_else.py # api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
This commit is contained in:
commit
db9b0ee985
3
.github/workflows/api-tests.yml
vendored
3
.github/workflows/api-tests.yml
vendored
@ -76,7 +76,7 @@ jobs:
|
|||||||
- name: Run Workflow
|
- name: Run Workflow
|
||||||
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
|
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
|
||||||
|
|
||||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
|
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch)
|
||||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||||
with:
|
with:
|
||||||
compose-file: |
|
compose-file: |
|
||||||
@ -90,5 +90,6 @@ jobs:
|
|||||||
pgvecto-rs
|
pgvecto-rs
|
||||||
pgvector
|
pgvector
|
||||||
chroma
|
chroma
|
||||||
|
elasticsearch
|
||||||
- name: Test Vector Stores
|
- name: Test Vector Stores
|
||||||
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
|
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
|
||||||
|
3
.github/workflows/expose_service_ports.sh
vendored
3
.github/workflows/expose_service_ports.sh
vendored
@ -6,5 +6,6 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
|||||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
|
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
|
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
|
||||||
|
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
|
||||||
|
|
||||||
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs."
|
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch"
|
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -45,6 +45,10 @@ jobs:
|
|||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
|
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
|
||||||
|
|
||||||
|
- name: Ruff formatter check
|
||||||
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
|
run: poetry run -C api ruff format --check ./api
|
||||||
|
|
||||||
- name: Lint hints
|
- name: Lint hints
|
||||||
if: failure()
|
if: failure()
|
||||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||||
|
@ -130,6 +130,12 @@ TENCENT_VECTOR_DB_DATABASE=dify
|
|||||||
TENCENT_VECTOR_DB_SHARD=1
|
TENCENT_VECTOR_DB_SHARD=1
|
||||||
TENCENT_VECTOR_DB_REPLICAS=2
|
TENCENT_VECTOR_DB_REPLICAS=2
|
||||||
|
|
||||||
|
# ElasticSearch configuration
|
||||||
|
ELASTICSEARCH_HOST=127.0.0.1
|
||||||
|
ELASTICSEARCH_PORT=9200
|
||||||
|
ELASTICSEARCH_USERNAME=elastic
|
||||||
|
ELASTICSEARCH_PASSWORD=elastic
|
||||||
|
|
||||||
# PGVECTO_RS configuration
|
# PGVECTO_RS configuration
|
||||||
PGVECTO_RS_HOST=localhost
|
PGVECTO_RS_HOST=localhost
|
||||||
PGVECTO_RS_PORT=5431
|
PGVECTO_RS_PORT=5431
|
||||||
|
151
api/app.py
151
api/app.py
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
if os.environ.get("DEBUG", "false").lower() != 'true':
|
if os.environ.get("DEBUG", "false").lower() != "true":
|
||||||
from gevent import monkey
|
from gevent import monkey
|
||||||
|
|
||||||
monkey.patch_all()
|
monkey.patch_all()
|
||||||
@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
|
|||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
os.system('tzutil /s "UTC"')
|
os.system('tzutil /s "UTC"')
|
||||||
else:
|
else:
|
||||||
os.environ['TZ'] = 'UTC'
|
os.environ["TZ"] = "UTC"
|
||||||
time.tzset()
|
time.tzset()
|
||||||
|
|
||||||
|
|
||||||
@ -70,13 +70,14 @@ class DifyApp(Flask):
|
|||||||
# -------------
|
# -------------
|
||||||
|
|
||||||
|
|
||||||
config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
|
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# Application Factory Function
|
# Application Factory Function
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
|
|
||||||
|
|
||||||
def create_flask_app_with_configs() -> Flask:
|
def create_flask_app_with_configs() -> Flask:
|
||||||
"""
|
"""
|
||||||
create a raw flask app
|
create a raw flask app
|
||||||
@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask:
|
|||||||
elif isinstance(value, int | float | bool):
|
elif isinstance(value, int | float | bool):
|
||||||
os.environ[key] = str(value)
|
os.environ[key] = str(value)
|
||||||
elif value is None:
|
elif value is None:
|
||||||
os.environ[key] = ''
|
os.environ[key] = ""
|
||||||
|
|
||||||
return dify_app
|
return dify_app
|
||||||
|
|
||||||
@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask:
|
|||||||
def create_app() -> Flask:
|
def create_app() -> Flask:
|
||||||
app = create_flask_app_with_configs()
|
app = create_flask_app_with_configs()
|
||||||
|
|
||||||
app.secret_key = app.config['SECRET_KEY']
|
app.secret_key = app.config["SECRET_KEY"]
|
||||||
|
|
||||||
log_handlers = None
|
log_handlers = None
|
||||||
log_file = app.config.get('LOG_FILE')
|
log_file = app.config.get("LOG_FILE")
|
||||||
if log_file:
|
if log_file:
|
||||||
log_dir = os.path.dirname(log_file)
|
log_dir = os.path.dirname(log_file)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
@ -111,23 +112,24 @@ def create_app() -> Flask:
|
|||||||
RotatingFileHandler(
|
RotatingFileHandler(
|
||||||
filename=log_file,
|
filename=log_file,
|
||||||
maxBytes=1024 * 1024 * 1024,
|
maxBytes=1024 * 1024 * 1024,
|
||||||
backupCount=5
|
backupCount=5,
|
||||||
),
|
),
|
||||||
logging.StreamHandler(sys.stdout)
|
logging.StreamHandler(sys.stdout),
|
||||||
]
|
]
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=app.config.get('LOG_LEVEL'),
|
level=app.config.get("LOG_LEVEL"),
|
||||||
format=app.config.get('LOG_FORMAT'),
|
format=app.config.get("LOG_FORMAT"),
|
||||||
datefmt=app.config.get('LOG_DATEFORMAT'),
|
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||||
handlers=log_handlers,
|
handlers=log_handlers,
|
||||||
force=True
|
force=True,
|
||||||
)
|
)
|
||||||
log_tz = app.config.get('LOG_TZ')
|
log_tz = app.config.get("LOG_TZ")
|
||||||
if log_tz:
|
if log_tz:
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
|
|
||||||
timezone = pytz.timezone(log_tz)
|
timezone = pytz.timezone(log_tz)
|
||||||
|
|
||||||
def time_converter(seconds):
|
def time_converter(seconds):
|
||||||
@ -162,24 +164,24 @@ def initialize_extensions(app):
|
|||||||
@login_manager.request_loader
|
@login_manager.request_loader
|
||||||
def load_user_from_request(request_from_flask_login):
|
def load_user_from_request(request_from_flask_login):
|
||||||
"""Load user based on the request."""
|
"""Load user based on the request."""
|
||||||
if request.blueprint not in ['console', 'inner_api']:
|
if request.blueprint not in ["console", "inner_api"]:
|
||||||
return None
|
return None
|
||||||
# Check if the user_id contains a dot, indicating the old format
|
# Check if the user_id contains a dot, indicating the old format
|
||||||
auth_header = request.headers.get('Authorization', '')
|
auth_header = request.headers.get("Authorization", "")
|
||||||
if not auth_header:
|
if not auth_header:
|
||||||
auth_token = request.args.get('_token')
|
auth_token = request.args.get("_token")
|
||||||
if not auth_token:
|
if not auth_token:
|
||||||
raise Unauthorized('Invalid Authorization token.')
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
else:
|
else:
|
||||||
if ' ' not in auth_header:
|
if " " not in auth_header:
|
||||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||||
auth_scheme = auth_scheme.lower()
|
auth_scheme = auth_scheme.lower()
|
||||||
if auth_scheme != 'bearer':
|
if auth_scheme != "bearer":
|
||||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
|
||||||
decoded = PassportService().verify(auth_token)
|
decoded = PassportService().verify(auth_token)
|
||||||
user_id = decoded.get('user_id')
|
user_id = decoded.get("user_id")
|
||||||
|
|
||||||
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
|
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
|
||||||
if account:
|
if account:
|
||||||
@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login):
|
|||||||
@login_manager.unauthorized_handler
|
@login_manager.unauthorized_handler
|
||||||
def unauthorized_handler():
|
def unauthorized_handler():
|
||||||
"""Handle unauthorized requests."""
|
"""Handle unauthorized requests."""
|
||||||
return Response(json.dumps({
|
return Response(
|
||||||
'code': 'unauthorized',
|
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||||
'message': "Unauthorized."
|
status=401,
|
||||||
}), status=401, content_type="application/json")
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# register blueprint routers
|
# register blueprint routers
|
||||||
@ -204,38 +207,36 @@ def register_blueprints(app):
|
|||||||
from controllers.service_api import bp as service_api_bp
|
from controllers.service_api import bp as service_api_bp
|
||||||
from controllers.web import bp as web_bp
|
from controllers.web import bp as web_bp
|
||||||
|
|
||||||
CORS(service_api_bp,
|
CORS(
|
||||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
service_api_bp,
|
||||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
)
|
)
|
||||||
app.register_blueprint(service_api_bp)
|
app.register_blueprint(service_api_bp)
|
||||||
|
|
||||||
CORS(web_bp,
|
CORS(
|
||||||
resources={
|
web_bp,
|
||||||
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
|
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
|
||||||
supports_credentials=True,
|
supports_credentials=True,
|
||||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
expose_headers=['X-Version', 'X-Env']
|
expose_headers=["X-Version", "X-Env"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.register_blueprint(web_bp)
|
app.register_blueprint(web_bp)
|
||||||
|
|
||||||
CORS(console_app_bp,
|
CORS(
|
||||||
resources={
|
console_app_bp,
|
||||||
r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
|
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
|
||||||
supports_credentials=True,
|
supports_credentials=True,
|
||||||
allow_headers=['Content-Type', 'Authorization'],
|
allow_headers=["Content-Type", "Authorization"],
|
||||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
expose_headers=['X-Version', 'X-Env']
|
expose_headers=["X-Version", "X-Env"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.register_blueprint(console_app_bp)
|
app.register_blueprint(console_app_bp)
|
||||||
|
|
||||||
CORS(files_bp,
|
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
||||||
allow_headers=['Content-Type'],
|
|
||||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
|
|
||||||
)
|
|
||||||
app.register_blueprint(files_bp)
|
app.register_blueprint(files_bp)
|
||||||
|
|
||||||
app.register_blueprint(inner_api_bp)
|
app.register_blueprint(inner_api_bp)
|
||||||
@ -245,29 +246,29 @@ def register_blueprints(app):
|
|||||||
app = create_app()
|
app = create_app()
|
||||||
celery = app.extensions["celery"]
|
celery = app.extensions["celery"]
|
||||||
|
|
||||||
if app.config.get('TESTING'):
|
if app.config.get("TESTING"):
|
||||||
print("App is running in TESTING mode")
|
print("App is running in TESTING mode")
|
||||||
|
|
||||||
|
|
||||||
@app.after_request
|
@app.after_request
|
||||||
def after_request(response):
|
def after_request(response):
|
||||||
"""Add Version headers to the response."""
|
"""Add Version headers to the response."""
|
||||||
response.set_cookie('remember_token', '', expires=0)
|
response.set_cookie("remember_token", "", expires=0)
|
||||||
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
|
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
|
||||||
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
|
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.route('/health')
|
@app.route("/health")
|
||||||
def health():
|
def health():
|
||||||
return Response(json.dumps({
|
return Response(
|
||||||
'pid': os.getpid(),
|
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
|
||||||
'status': 'ok',
|
status=200,
|
||||||
'version': app.config['CURRENT_VERSION']
|
content_type="application/json",
|
||||||
}), status=200, content_type="application/json")
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/threads')
|
@app.route("/threads")
|
||||||
def threads():
|
def threads():
|
||||||
num_threads = threading.active_count()
|
num_threads = threading.active_count()
|
||||||
threads = threading.enumerate()
|
threads = threading.enumerate()
|
||||||
@ -278,32 +279,34 @@ def threads():
|
|||||||
thread_id = thread.ident
|
thread_id = thread.ident
|
||||||
is_alive = thread.is_alive()
|
is_alive = thread.is_alive()
|
||||||
|
|
||||||
thread_list.append({
|
thread_list.append(
|
||||||
'name': thread_name,
|
{
|
||||||
'id': thread_id,
|
"name": thread_name,
|
||||||
'is_alive': is_alive
|
"id": thread_id,
|
||||||
})
|
"is_alive": is_alive,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'pid': os.getpid(),
|
"pid": os.getpid(),
|
||||||
'thread_num': num_threads,
|
"thread_num": num_threads,
|
||||||
'threads': thread_list
|
"threads": thread_list,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.route('/db-pool-stat')
|
@app.route("/db-pool-stat")
|
||||||
def pool_stat():
|
def pool_stat():
|
||||||
engine = db.engine
|
engine = db.engine
|
||||||
return {
|
return {
|
||||||
'pid': os.getpid(),
|
"pid": os.getpid(),
|
||||||
'pool_size': engine.pool.size(),
|
"pool_size": engine.pool.size(),
|
||||||
'checked_in_connections': engine.pool.checkedin(),
|
"checked_in_connections": engine.pool.checkedin(),
|
||||||
'checked_out_connections': engine.pool.checkedout(),
|
"checked_out_connections": engine.pool.checkedout(),
|
||||||
'overflow_connections': engine.pool.overflow(),
|
"overflow_connections": engine.pool.overflow(),
|
||||||
'connection_timeout': engine.pool.timeout(),
|
"connection_timeout": engine.pool.timeout(),
|
||||||
'recycle_time': db.engine.pool._recycle
|
"recycle_time": db.engine.pool._recycle,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
app.run(host='0.0.0.0', port=5001)
|
app.run(host="0.0.0.0", port=5001)
|
||||||
|
407
api/commands.py
407
api/commands.py
@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel
|
|||||||
from services.account_service import RegisterService, TenantService
|
from services.account_service import RegisterService, TenantService
|
||||||
|
|
||||||
|
|
||||||
@click.command('reset-password', help='Reset the account password.')
|
@click.command("reset-password", help="Reset the account password.")
|
||||||
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
|
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
|
||||||
@click.option('--new-password', prompt=True, help='the new password.')
|
@click.option("--new-password", prompt=True, help="the new password.")
|
||||||
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
|
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
|
||||||
def reset_password(email, new_password, password_confirm):
|
def reset_password(email, new_password, password_confirm):
|
||||||
"""
|
"""
|
||||||
Reset password of owner account
|
Reset password of owner account
|
||||||
Only available in SELF_HOSTED mode
|
Only available in SELF_HOSTED mode
|
||||||
"""
|
"""
|
||||||
if str(new_password).strip() != str(password_confirm).strip():
|
if str(new_password).strip() != str(password_confirm).strip():
|
||||||
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
|
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
account = db.session.query(Account). \
|
account = db.session.query(Account).filter(Account.email == email).one_or_none()
|
||||||
filter(Account.email == email). \
|
|
||||||
one_or_none()
|
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
|
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
valid_password(new_password)
|
valid_password(new_password)
|
||||||
except:
|
except:
|
||||||
click.echo(
|
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
|
||||||
click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# generate password salt
|
# generate password salt
|
||||||
@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm):
|
|||||||
account.password = base64_password_hashed
|
account.password = base64_password_hashed
|
||||||
account.password_salt = base64_salt
|
account.password_salt = base64_salt
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
|
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
@click.command('reset-email', help='Reset the account email.')
|
@click.command("reset-email", help="Reset the account email.")
|
||||||
@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
|
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
|
||||||
@click.option('--new-email', prompt=True, help='the new email.')
|
@click.option("--new-email", prompt=True, help="the new email.")
|
||||||
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
|
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
|
||||||
def reset_email(email, new_email, email_confirm):
|
def reset_email(email, new_email, email_confirm):
|
||||||
"""
|
"""
|
||||||
Replace account email
|
Replace account email
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if str(new_email).strip() != str(email_confirm).strip():
|
if str(new_email).strip() != str(email_confirm).strip():
|
||||||
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
|
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
account = db.session.query(Account). \
|
account = db.session.query(Account).filter(Account.email == email).one_or_none()
|
||||||
filter(Account.email == email). \
|
|
||||||
one_or_none()
|
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
|
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
email_validate(new_email)
|
email_validate(new_email)
|
||||||
except:
|
except:
|
||||||
click.echo(
|
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
|
||||||
click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
account.email = new_email
|
account.email = new_email
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
|
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
|
@click.command(
|
||||||
'After the reset, all LLM credentials will become invalid, '
|
"reset-encrypt-key-pair",
|
||||||
'requiring re-entry.'
|
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
|
||||||
'Only support SELF_HOSTED mode.')
|
"After the reset, all LLM credentials will become invalid, "
|
||||||
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
|
"requiring re-entry."
|
||||||
' this operation cannot be rolled back!', fg='red'))
|
"Only support SELF_HOSTED mode.",
|
||||||
|
)
|
||||||
|
@click.confirmation_option(
|
||||||
|
prompt=click.style(
|
||||||
|
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
|
||||||
|
)
|
||||||
|
)
|
||||||
def reset_encrypt_key_pair():
|
def reset_encrypt_key_pair():
|
||||||
"""
|
"""
|
||||||
Reset the encrypted key pair of workspace for encrypt LLM credentials.
|
Reset the encrypted key pair of workspace for encrypt LLM credentials.
|
||||||
After the reset, all LLM credentials will become invalid, requiring re-entry.
|
After the reset, all LLM credentials will become invalid, requiring re-entry.
|
||||||
Only support SELF_HOSTED mode.
|
Only support SELF_HOSTED mode.
|
||||||
"""
|
"""
|
||||||
if dify_config.EDITION != 'SELF_HOSTED':
|
if dify_config.EDITION != "SELF_HOSTED":
|
||||||
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
|
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
tenants = db.session.query(Tenant).all()
|
tenants = db.session.query(Tenant).all()
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
if not tenant:
|
if not tenant:
|
||||||
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
|
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||||
|
|
||||||
db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
|
db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||||
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
|
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
click.echo(click.style('Congratulations! '
|
click.echo(
|
||||||
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
|
click.style(
|
||||||
|
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@click.command('vdb-migrate', help='migrate vector db.')
|
@click.command("vdb-migrate", help="migrate vector db.")
|
||||||
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
|
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
|
||||||
def vdb_migrate(scope: str):
|
def vdb_migrate(scope: str):
|
||||||
if scope in ['knowledge', 'all']:
|
if scope in ["knowledge", "all"]:
|
||||||
migrate_knowledge_vector_database()
|
migrate_knowledge_vector_database()
|
||||||
if scope in ['annotation', 'all']:
|
if scope in ["annotation", "all"]:
|
||||||
migrate_annotation_vector_database()
|
migrate_annotation_vector_database()
|
||||||
|
|
||||||
|
|
||||||
@ -146,7 +150,7 @@ def migrate_annotation_vector_database():
|
|||||||
"""
|
"""
|
||||||
Migrate annotation datas to target vector database .
|
Migrate annotation datas to target vector database .
|
||||||
"""
|
"""
|
||||||
click.echo(click.style('Start migrate annotation data.', fg='green'))
|
click.echo(click.style("Start migrate annotation data.", fg="green"))
|
||||||
create_count = 0
|
create_count = 0
|
||||||
skipped_count = 0
|
skipped_count = 0
|
||||||
total_count = 0
|
total_count = 0
|
||||||
@ -154,98 +158,103 @@ def migrate_annotation_vector_database():
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# get apps info
|
# get apps info
|
||||||
apps = db.session.query(App).filter(
|
apps = (
|
||||||
App.status == 'normal'
|
db.session.query(App)
|
||||||
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
|
.filter(App.status == "normal")
|
||||||
|
.order_by(App.created_at.desc())
|
||||||
|
.paginate(page=page, per_page=50)
|
||||||
|
)
|
||||||
except NotFound:
|
except NotFound:
|
||||||
break
|
break
|
||||||
|
|
||||||
page += 1
|
page += 1
|
||||||
for app in apps:
|
for app in apps:
|
||||||
total_count = total_count + 1
|
total_count = total_count + 1
|
||||||
click.echo(f'Processing the {total_count} app {app.id}. '
|
click.echo(
|
||||||
+ f'{create_count} created, {skipped_count} skipped.')
|
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
click.echo('Create app annotation index: {}'.format(app.id))
|
click.echo("Create app annotation index: {}".format(app.id))
|
||||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
app_annotation_setting = (
|
||||||
AppAnnotationSetting.app_id == app.id
|
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
|
||||||
).first()
|
)
|
||||||
|
|
||||||
if not app_annotation_setting:
|
if not app_annotation_setting:
|
||||||
skipped_count = skipped_count + 1
|
skipped_count = skipped_count + 1
|
||||||
click.echo('App annotation setting is disabled: {}'.format(app.id))
|
click.echo("App annotation setting is disabled: {}".format(app.id))
|
||||||
continue
|
continue
|
||||||
# get dataset_collection_binding info
|
# get dataset_collection_binding info
|
||||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
|
dataset_collection_binding = (
|
||||||
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
|
db.session.query(DatasetCollectionBinding)
|
||||||
).first()
|
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not dataset_collection_binding:
|
if not dataset_collection_binding:
|
||||||
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
|
click.echo("App annotation collection binding is not exist: {}".format(app.id))
|
||||||
continue
|
continue
|
||||||
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
|
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
id=app.id,
|
id=app.id,
|
||||||
tenant_id=app.tenant_id,
|
tenant_id=app.tenant_id,
|
||||||
indexing_technique='high_quality',
|
indexing_technique="high_quality",
|
||||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||||
embedding_model=dataset_collection_binding.model_name,
|
embedding_model=dataset_collection_binding.model_name,
|
||||||
collection_binding_id=dataset_collection_binding.id
|
collection_binding_id=dataset_collection_binding.id,
|
||||||
)
|
)
|
||||||
documents = []
|
documents = []
|
||||||
if annotations:
|
if annotations:
|
||||||
for annotation in annotations:
|
for annotation in annotations:
|
||||||
document = Document(
|
document = Document(
|
||||||
page_content=annotation.question,
|
page_content=annotation.question,
|
||||||
metadata={
|
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
|
||||||
"annotation_id": annotation.id,
|
|
||||||
"app_id": app.id,
|
|
||||||
"doc_id": annotation.id
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
|
|
||||||
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
|
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||||
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
|
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
vector.delete()
|
vector.delete()
|
||||||
click.echo(
|
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
|
||||||
click.style(f'Successfully delete vector index for app: {app.id}.',
|
|
||||||
fg='green'))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(
|
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
|
||||||
click.style(f'Failed to delete vector index for app {app.id}.',
|
|
||||||
fg='red'))
|
|
||||||
raise e
|
raise e
|
||||||
if documents:
|
if documents:
|
||||||
try:
|
try:
|
||||||
click.echo(click.style(
|
|
||||||
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
|
|
||||||
fg='green'))
|
|
||||||
vector.create(documents)
|
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
|
click.style(
|
||||||
|
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
vector.create(documents)
|
||||||
|
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
|
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
|
||||||
raise e
|
raise e
|
||||||
click.echo(f'Successfully migrated app annotation {app.id}.')
|
click.echo(f"Successfully migrated app annotation {app.id}.")
|
||||||
create_count += 1
|
create_count += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
|
click.style(
|
||||||
fg='red'))
|
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
|
||||||
|
)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
|
click.style(
|
||||||
fg='green'))
|
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def migrate_knowledge_vector_database():
|
def migrate_knowledge_vector_database():
|
||||||
"""
|
"""
|
||||||
Migrate vector database datas to target vector database .
|
Migrate vector database datas to target vector database .
|
||||||
"""
|
"""
|
||||||
click.echo(click.style('Start migrate vector db.', fg='green'))
|
click.echo(click.style("Start migrate vector db.", fg="green"))
|
||||||
create_count = 0
|
create_count = 0
|
||||||
skipped_count = 0
|
skipped_count = 0
|
||||||
total_count = 0
|
total_count = 0
|
||||||
@ -253,87 +262,77 @@ def migrate_knowledge_vector_database():
|
|||||||
page = 1
|
page = 1
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
datasets = (
|
||||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
db.session.query(Dataset)
|
||||||
|
.filter(Dataset.indexing_technique == "high_quality")
|
||||||
|
.order_by(Dataset.created_at.desc())
|
||||||
|
.paginate(page=page, per_page=50)
|
||||||
|
)
|
||||||
except NotFound:
|
except NotFound:
|
||||||
break
|
break
|
||||||
|
|
||||||
page += 1
|
page += 1
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
total_count = total_count + 1
|
total_count = total_count + 1
|
||||||
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
|
click.echo(
|
||||||
+ f'{create_count} created, {skipped_count} skipped.')
|
f"Processing the {total_count} dataset {dataset.id}. "
|
||||||
|
+ f"{create_count} created, {skipped_count} skipped."
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
click.echo('Create dataset vdb index: {}'.format(dataset.id))
|
click.echo("Create dataset vdb index: {}".format(dataset.id))
|
||||||
if dataset.index_struct_dict:
|
if dataset.index_struct_dict:
|
||||||
if dataset.index_struct_dict['type'] == vector_type:
|
if dataset.index_struct_dict["type"] == vector_type:
|
||||||
skipped_count = skipped_count + 1
|
skipped_count = skipped_count + 1
|
||||||
continue
|
continue
|
||||||
collection_name = ''
|
collection_name = ""
|
||||||
if vector_type == VectorType.WEAVIATE:
|
if vector_type == VectorType.WEAVIATE:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {
|
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
|
||||||
"type": VectorType.WEAVIATE,
|
|
||||||
"vector_store": {"class_prefix": collection_name}
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
elif vector_type == VectorType.QDRANT:
|
elif vector_type == VectorType.QDRANT:
|
||||||
if dataset.collection_binding_id:
|
if dataset.collection_binding_id:
|
||||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
dataset_collection_binding = (
|
||||||
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
|
db.session.query(DatasetCollectionBinding)
|
||||||
one_or_none()
|
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
if dataset_collection_binding:
|
if dataset_collection_binding:
|
||||||
collection_name = dataset_collection_binding.collection_name
|
collection_name = dataset_collection_binding.collection_name
|
||||||
else:
|
else:
|
||||||
raise ValueError('Dataset Collection Bindings is not exist!')
|
raise ValueError("Dataset Collection Bindings is not exist!")
|
||||||
else:
|
else:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {
|
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
|
||||||
"type": VectorType.QDRANT,
|
|
||||||
"vector_store": {"class_prefix": collection_name}
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
|
|
||||||
elif vector_type == VectorType.MILVUS:
|
elif vector_type == VectorType.MILVUS:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {
|
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
|
||||||
"type": VectorType.MILVUS,
|
|
||||||
"vector_store": {"class_prefix": collection_name}
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
elif vector_type == VectorType.RELYT:
|
elif vector_type == VectorType.RELYT:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {
|
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
|
||||||
"type": 'relyt',
|
|
||||||
"vector_store": {"class_prefix": collection_name}
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
elif vector_type == VectorType.TENCENT:
|
elif vector_type == VectorType.TENCENT:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {
|
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
|
||||||
"type": VectorType.TENCENT,
|
|
||||||
"vector_store": {"class_prefix": collection_name}
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
elif vector_type == VectorType.PGVECTOR:
|
elif vector_type == VectorType.PGVECTOR:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {
|
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
|
||||||
"type": VectorType.PGVECTOR,
|
|
||||||
"vector_store": {"class_prefix": collection_name}
|
|
||||||
}
|
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
elif vector_type == VectorType.OPENSEARCH:
|
elif vector_type == VectorType.OPENSEARCH:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {
|
index_struct_dict = {
|
||||||
"type": VectorType.OPENSEARCH,
|
"type": VectorType.OPENSEARCH,
|
||||||
"vector_store": {"class_prefix": collection_name}
|
"vector_store": {"class_prefix": collection_name},
|
||||||
}
|
}
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
elif vector_type == VectorType.ANALYTICDB:
|
elif vector_type == VectorType.ANALYTICDB:
|
||||||
@ -341,9 +340,14 @@ def migrate_knowledge_vector_database():
|
|||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
index_struct_dict = {
|
index_struct_dict = {
|
||||||
"type": VectorType.ANALYTICDB,
|
"type": VectorType.ANALYTICDB,
|
||||||
"vector_store": {"class_prefix": collection_name}
|
"vector_store": {"class_prefix": collection_name},
|
||||||
}
|
}
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
|
elif vector_type == VectorType.ELASTICSEARCH:
|
||||||
|
dataset_id = dataset.id
|
||||||
|
index_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
|
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
|
||||||
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||||
|
|
||||||
@ -353,29 +357,41 @@ def migrate_knowledge_vector_database():
|
|||||||
try:
|
try:
|
||||||
vector.delete()
|
vector.delete()
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
|
click.style(
|
||||||
fg='green'))
|
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
|
click.style(
|
||||||
fg='red'))
|
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
|
||||||
|
)
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
dataset_documents = (
|
||||||
|
db.session.query(DatasetDocument)
|
||||||
|
.filter(
|
||||||
DatasetDocument.dataset_id == dataset.id,
|
DatasetDocument.dataset_id == dataset.id,
|
||||||
DatasetDocument.indexing_status == 'completed',
|
DatasetDocument.indexing_status == "completed",
|
||||||
DatasetDocument.enabled == True,
|
DatasetDocument.enabled == True,
|
||||||
DatasetDocument.archived == False,
|
DatasetDocument.archived == False,
|
||||||
).all()
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
segments_count = 0
|
segments_count = 0
|
||||||
for dataset_document in dataset_documents:
|
for dataset_document in dataset_documents:
|
||||||
segments = db.session.query(DocumentSegment).filter(
|
segments = (
|
||||||
|
db.session.query(DocumentSegment)
|
||||||
|
.filter(
|
||||||
DocumentSegment.document_id == dataset_document.id,
|
DocumentSegment.document_id == dataset_document.id,
|
||||||
DocumentSegment.status == 'completed',
|
DocumentSegment.status == "completed",
|
||||||
DocumentSegment.enabled == True
|
DocumentSegment.enabled == True,
|
||||||
).all()
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
document = Document(
|
document = Document(
|
||||||
@ -385,7 +401,7 @@ def migrate_knowledge_vector_database():
|
|||||||
"doc_hash": segment.index_node_hash,
|
"doc_hash": segment.index_node_hash,
|
||||||
"document_id": segment.document_id,
|
"document_id": segment.document_id,
|
||||||
"dataset_id": segment.dataset_id,
|
"dataset_id": segment.dataset_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
@ -393,37 +409,43 @@ def migrate_knowledge_vector_database():
|
|||||||
|
|
||||||
if documents:
|
if documents:
|
||||||
try:
|
try:
|
||||||
click.echo(click.style(
|
click.echo(
|
||||||
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
|
click.style(
|
||||||
fg='green'))
|
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
vector.create(documents)
|
vector.create(documents)
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
|
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
|
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
|
||||||
raise e
|
raise e
|
||||||
db.session.add(dataset)
|
db.session.add(dataset)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
click.echo(f'Successfully migrated dataset {dataset.id}.')
|
click.echo(f"Successfully migrated dataset {dataset.id}.")
|
||||||
create_count += 1
|
create_count += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
|
||||||
fg='red'))
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
click.echo(
|
click.echo(
|
||||||
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
|
click.style(
|
||||||
fg='green'))
|
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
|
@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
|
||||||
def convert_to_agent_apps():
|
def convert_to_agent_apps():
|
||||||
"""
|
"""
|
||||||
Convert Agent Assistant to Agent App.
|
Convert Agent Assistant to Agent App.
|
||||||
"""
|
"""
|
||||||
click.echo(click.style('Start convert to agent apps.', fg='green'))
|
click.echo(click.style("Start convert to agent apps.", fg="green"))
|
||||||
|
|
||||||
proceeded_app_ids = []
|
proceeded_app_ids = []
|
||||||
|
|
||||||
@ -458,7 +480,7 @@ def convert_to_agent_apps():
|
|||||||
break
|
break
|
||||||
|
|
||||||
for app in apps:
|
for app in apps:
|
||||||
click.echo('Converting app: {}'.format(app.id))
|
click.echo("Converting app: {}".format(app.id))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
app.mode = AppMode.AGENT_CHAT.value
|
app.mode = AppMode.AGENT_CHAT.value
|
||||||
@ -470,137 +492,139 @@ def convert_to_agent_apps():
|
|||||||
)
|
)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
|
click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(
|
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
|
||||||
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
|
|
||||||
str(e)), fg='red'))
|
|
||||||
|
|
||||||
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
|
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
|
||||||
|
|
||||||
|
|
||||||
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
|
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
|
||||||
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
|
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
|
||||||
def add_qdrant_doc_id_index(field: str):
|
def add_qdrant_doc_id_index(field: str):
|
||||||
click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
|
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
|
||||||
vector_type = dify_config.VECTOR_STORE
|
vector_type = dify_config.VECTOR_STORE
|
||||||
if vector_type != "qdrant":
|
if vector_type != "qdrant":
|
||||||
click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
|
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
|
||||||
return
|
return
|
||||||
create_count = 0
|
create_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
bindings = db.session.query(DatasetCollectionBinding).all()
|
bindings = db.session.query(DatasetCollectionBinding).all()
|
||||||
if not bindings:
|
if not bindings:
|
||||||
click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
|
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
|
||||||
return
|
return
|
||||||
import qdrant_client
|
import qdrant_client
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
from qdrant_client.http.models import PayloadSchemaType
|
from qdrant_client.http.models import PayloadSchemaType
|
||||||
|
|
||||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
|
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
|
||||||
|
|
||||||
for binding in bindings:
|
for binding in bindings:
|
||||||
if dify_config.QDRANT_URL is None:
|
if dify_config.QDRANT_URL is None:
|
||||||
raise ValueError('Qdrant url is required.')
|
raise ValueError("Qdrant url is required.")
|
||||||
qdrant_config = QdrantConfig(
|
qdrant_config = QdrantConfig(
|
||||||
endpoint=dify_config.QDRANT_URL,
|
endpoint=dify_config.QDRANT_URL,
|
||||||
api_key=dify_config.QDRANT_API_KEY,
|
api_key=dify_config.QDRANT_API_KEY,
|
||||||
root_path=current_app.root_path,
|
root_path=current_app.root_path,
|
||||||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
|
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
|
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
|
||||||
# create payload index
|
# create payload index
|
||||||
client.create_payload_index(binding.collection_name, field,
|
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
|
||||||
field_schema=PayloadSchemaType.KEYWORD)
|
|
||||||
create_count += 1
|
create_count += 1
|
||||||
except UnexpectedResponse as e:
|
except UnexpectedResponse as e:
|
||||||
# Collection does not exist, so return
|
# Collection does not exist, so return
|
||||||
if e.status_code == 404:
|
if e.status_code == 404:
|
||||||
click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
|
click.echo(
|
||||||
|
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
# Some other error occurred, so re-raise the exception
|
# Some other error occurred, so re-raise the exception
|
||||||
else:
|
else:
|
||||||
click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(click.style('Failed to create qdrant client.', fg='red'))
|
click.echo(click.style("Failed to create qdrant client.", fg="red"))
|
||||||
|
|
||||||
click.echo(
|
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
|
||||||
click.style(f'Congratulations! Create {create_count} collection indexes.',
|
|
||||||
fg='green'))
|
|
||||||
|
|
||||||
|
|
||||||
@click.command('create-tenant', help='Create account and tenant.')
|
@click.command("create-tenant", help="Create account and tenant.")
|
||||||
@click.option('--email', prompt=True, help='The email address of the tenant account.')
|
@click.option("--email", prompt=True, help="The email address of the tenant account.")
|
||||||
@click.option('--language', prompt=True, help='Account language, default: en-US.')
|
@click.option("--language", prompt=True, help="Account language, default: en-US.")
|
||||||
def create_tenant(email: str, language: Optional[str] = None):
|
def create_tenant(email: str, language: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Create tenant account
|
Create tenant account
|
||||||
"""
|
"""
|
||||||
if not email:
|
if not email:
|
||||||
click.echo(click.style('Sorry, email is required.', fg='red'))
|
click.echo(click.style("Sorry, email is required.", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create account
|
# Create account
|
||||||
email = email.strip()
|
email = email.strip()
|
||||||
|
|
||||||
if '@' not in email:
|
if "@" not in email:
|
||||||
click.echo(click.style('Sorry, invalid email address.', fg='red'))
|
click.echo(click.style("Sorry, invalid email address.", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
account_name = email.split('@')[0]
|
account_name = email.split("@")[0]
|
||||||
|
|
||||||
if language not in languages:
|
if language not in languages:
|
||||||
language = 'en-US'
|
language = "en-US"
|
||||||
|
|
||||||
# generate random password
|
# generate random password
|
||||||
new_password = secrets.token_urlsafe(16)
|
new_password = secrets.token_urlsafe(16)
|
||||||
|
|
||||||
# register account
|
# register account
|
||||||
account = RegisterService.register(
|
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
|
||||||
email=email,
|
|
||||||
name=account_name,
|
|
||||||
password=new_password,
|
|
||||||
language=language
|
|
||||||
)
|
|
||||||
|
|
||||||
TenantService.create_owner_tenant_if_not_exist(account)
|
TenantService.create_owner_tenant_if_not_exist(account)
|
||||||
|
|
||||||
click.echo(click.style('Congratulations! Account and tenant created.\n'
|
click.echo(
|
||||||
'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
|
click.style(
|
||||||
|
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
-@click.command('upgrade-db', help='upgrade the database')
+@click.command("upgrade-db", help="upgrade the database")
def upgrade_db():
-click.echo('Preparing database migration...')
+click.echo("Preparing database migration...")
-lock = redis_client.lock(name='db_upgrade_lock', timeout=60)
+lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
try:
-click.echo(click.style('Start database migration.', fg='green'))
+click.echo(click.style("Start database migration.", fg="green"))

# run db migration
import flask_migrate

flask_migrate.upgrade()

-click.echo(click.style('Database migration successful!', fg='green'))
+click.echo(click.style("Database migration successful!", fg="green"))

except Exception as e:
-logging.exception(f'Database migration failed, error: {e}')
+logging.exception(f"Database migration failed, error: {e}")
finally:
lock.release()
else:
-click.echo('Database migration skipped')
+click.echo("Database migration skipped")

-@click.command('fix-app-site-missing', help='Fix app related site missing issue.')
+@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
def fix_app_site_missing():
"""
Fix app related site missing issue.
"""
-click.echo(click.style('Start fix app related site missing issue.', fg='green'))
+click.echo(click.style("Start fix app related site missing issue.", fg="green"))

failed_app_ids = []
while True:
@ -631,15 +655,14 @@ where sites.id is null limit 1000"""
app_was_created.send(app, account=account)
except Exception as e:
failed_app_ids.append(app_id)
-click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red'))
+click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
-logging.exception(f'Fix app related site missing issue failed, error: {e}')
+logging.exception(f"Fix app related site missing issue failed, error: {e}")
continue

if not processed_count:
break

-click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
+click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))


def register_commands(app):
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

CURRENT_VERSION: str = Field(
description='Dify version',
-default='0.6.16',
+default='0.7.0',
)

COMMIT_SHA: str = Field(
@ -1 +1 @@
-HIDDEN_VALUE = '[__HIDDEN__]'
+HIDDEN_VALUE = "[__HIDDEN__]"
@ -1,22 +1,22 @@
language_timezone_mapping = {
-'en-US': 'America/New_York',
+"en-US": "America/New_York",
-'zh-Hans': 'Asia/Shanghai',
+"zh-Hans": "Asia/Shanghai",
-'zh-Hant': 'Asia/Taipei',
+"zh-Hant": "Asia/Taipei",
-'pt-BR': 'America/Sao_Paulo',
+"pt-BR": "America/Sao_Paulo",
-'es-ES': 'Europe/Madrid',
+"es-ES": "Europe/Madrid",
-'fr-FR': 'Europe/Paris',
+"fr-FR": "Europe/Paris",
-'de-DE': 'Europe/Berlin',
+"de-DE": "Europe/Berlin",
-'ja-JP': 'Asia/Tokyo',
+"ja-JP": "Asia/Tokyo",
-'ko-KR': 'Asia/Seoul',
+"ko-KR": "Asia/Seoul",
-'ru-RU': 'Europe/Moscow',
+"ru-RU": "Europe/Moscow",
-'it-IT': 'Europe/Rome',
+"it-IT": "Europe/Rome",
-'uk-UA': 'Europe/Kyiv',
+"uk-UA": "Europe/Kyiv",
-'vi-VN': 'Asia/Ho_Chi_Minh',
+"vi-VN": "Asia/Ho_Chi_Minh",
-'ro-RO': 'Europe/Bucharest',
+"ro-RO": "Europe/Bucharest",
-'pl-PL': 'Europe/Warsaw',
+"pl-PL": "Europe/Warsaw",
-'hi-IN': 'Asia/Kolkata',
+"hi-IN": "Asia/Kolkata",
-'tr-TR': 'Europe/Istanbul',
+"tr-TR": "Europe/Istanbul",
-'fa-IR': 'Asia/Tehran',
+"fa-IR": "Asia/Tehran",
}

languages = list(language_timezone_mapping.keys())
@ -26,6 +26,5 @@ def supported_language(lang):
if lang in languages:
return lang

-error = ('{lang} is not a valid language.'
-.format(lang=lang))
+error = "{lang} is not a valid language.".format(lang=lang)
raise ValueError(error)
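For reference, the reworked supported_language helper above keeps its behavior: a known language code is returned unchanged and anything else raises ValueError. A minimal usage sketch, assuming the helpers are importable from constants.languages as elsewhere in the api package; the inputs are illustrative, not taken from the commit:

# Minimal sketch of how callers use the language helpers shown above.
from constants.languages import language_timezone_mapping, languages, supported_language

lang = supported_language("en-US")            # returns "en-US" because it is in `languages`
default_tz = language_timezone_mapping[lang]  # "America/New_York"

try:
    supported_language("xx-XX")               # not in `languages`
except ValueError as e:
    print(e)                                  # "xx-XX is not a valid language."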
@ -5,82 +5,79 @@ from models.model import AppMode
|
|||||||
default_app_templates = {
|
default_app_templates = {
|
||||||
# workflow default mode
|
# workflow default mode
|
||||||
AppMode.WORKFLOW: {
|
AppMode.WORKFLOW: {
|
||||||
'app': {
|
"app": {
|
||||||
'mode': AppMode.WORKFLOW.value,
|
"mode": AppMode.WORKFLOW.value,
|
||||||
'enable_site': True,
|
"enable_site": True,
|
||||||
'enable_api': True
|
"enable_api": True,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
# completion default mode
|
# completion default mode
|
||||||
AppMode.COMPLETION: {
|
AppMode.COMPLETION: {
|
||||||
'app': {
|
"app": {
|
||||||
'mode': AppMode.COMPLETION.value,
|
"mode": AppMode.COMPLETION.value,
|
||||||
'enable_site': True,
|
"enable_site": True,
|
||||||
'enable_api': True
|
"enable_api": True,
|
||||||
},
|
},
|
||||||
'model_config': {
|
"model_config": {
|
||||||
'model': {
|
"model": {
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
"name": "gpt-4o",
|
"name": "gpt-4o",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"completion_params": {}
|
"completion_params": {},
|
||||||
},
|
},
|
||||||
'user_input_form': json.dumps([
|
"user_input_form": json.dumps(
|
||||||
|
[
|
||||||
{
|
{
|
||||||
"paragraph": {
|
"paragraph": {
|
||||||
"label": "Query",
|
"label": "Query",
|
||||||
"variable": "query",
|
"variable": "query",
|
||||||
"required": True,
|
"required": True,
|
||||||
"default": ""
|
"default": "",
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]),
|
]
|
||||||
'pre_prompt': '{{query}}'
|
),
|
||||||
|
"pre_prompt": "{{query}}",
|
||||||
},
|
},
|
||||||
|
|
||||||
},
|
},
|
||||||
|
|
||||||
# chat default mode
|
# chat default mode
|
||||||
AppMode.CHAT: {
|
AppMode.CHAT: {
|
||||||
'app': {
|
"app": {
|
||||||
'mode': AppMode.CHAT.value,
|
"mode": AppMode.CHAT.value,
|
||||||
'enable_site': True,
|
"enable_site": True,
|
||||||
'enable_api': True
|
"enable_api": True,
|
||||||
},
|
},
|
||||||
'model_config': {
|
"model_config": {
|
||||||
'model': {
|
"model": {
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
"name": "gpt-4o",
|
"name": "gpt-4o",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"completion_params": {}
|
"completion_params": {},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
# advanced-chat default mode
|
# advanced-chat default mode
|
||||||
AppMode.ADVANCED_CHAT: {
|
AppMode.ADVANCED_CHAT: {
|
||||||
'app': {
|
"app": {
|
||||||
'mode': AppMode.ADVANCED_CHAT.value,
|
"mode": AppMode.ADVANCED_CHAT.value,
|
||||||
'enable_site': True,
|
"enable_site": True,
|
||||||
'enable_api': True
|
"enable_api": True,
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
# agent-chat default mode
|
# agent-chat default mode
|
||||||
AppMode.AGENT_CHAT: {
|
AppMode.AGENT_CHAT: {
|
||||||
'app': {
|
"app": {
|
||||||
'mode': AppMode.AGENT_CHAT.value,
|
"mode": AppMode.AGENT_CHAT.value,
|
||||||
'enable_site': True,
|
"enable_site": True,
|
||||||
'enable_api': True
|
"enable_api": True,
|
||||||
},
|
},
|
||||||
'model_config': {
|
"model_config": {
|
||||||
'model': {
|
"model": {
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
"name": "gpt-4o",
|
"name": "gpt-4o",
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"completion_params": {}
|
"completion_params": {},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,7 @@
from contextvars import ContextVar

-tenant_id: ContextVar[str] = ContextVar('tenant_id')
+from core.workflow.entities.variable_pool import VariablePool
+
+tenant_id: ContextVar[str] = ContextVar("tenant_id")
+
+workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
@ -33,7 +33,7 @@ class CompletionConversationApi(Resource):
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
def get(self, app_model):
-if not current_user.is_admin_or_owner:
+if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
def get(self, app_model, conversation_id):
-if not current_user.is_admin_or_owner:
+if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)

@ -119,7 +119,7 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def delete(self, app_model, conversation_id):
-if not current_user.is_admin_or_owner:
+if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)

@ -256,7 +256,7 @@ class ChatConversationDetailApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
def delete(self, app_model, conversation_id):
-if not current_user.is_admin_or_owner:
+if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)

@ -555,7 +555,7 @@ class DatasetRetrievalSettingApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
-case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
+case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
@ -579,7 +579,7 @@ class DatasetRetrievalSettingMockApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
-case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
+case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
@ -178,11 +178,20 @@ class DatasetDocumentListApi(Resource):
|
|||||||
.subquery()
|
.subquery()
|
||||||
|
|
||||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
|
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
|
||||||
.order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)))
|
.order_by(
|
||||||
|
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||||
|
sort_logic(Document.position),
|
||||||
|
)
|
||||||
elif sort == 'created_at':
|
elif sort == 'created_at':
|
||||||
query = query.order_by(sort_logic(Document.created_at))
|
query = query.order_by(
|
||||||
|
sort_logic(Document.created_at),
|
||||||
|
sort_logic(Document.position),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
query = query.order_by(desc(Document.created_at))
|
query = query.order_by(
|
||||||
|
desc(Document.created_at),
|
||||||
|
desc(Document.position),
|
||||||
|
)
|
||||||
|
|
||||||
paginated_documents = query.paginate(
|
paginated_documents = query.paginate(
|
||||||
page=page, per_page=limit, max_per_page=100, error_out=False)
|
page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||||
|
@ -93,6 +93,7 @@ class DatasetConfigManager:
|
|||||||
reranking_model=dataset_configs.get('reranking_model'),
|
reranking_model=dataset_configs.get('reranking_model'),
|
||||||
weights=dataset_configs.get('weights'),
|
weights=dataset_configs.get('weights'),
|
||||||
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
||||||
|
rerank_mode=dataset_configs["reranking_mode"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -232,8 +232,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
'queue_manager': queue_manager,
|
'queue_manager': queue_manager,
|
||||||
'conversation_id': conversation.id,
|
'conversation_id': conversation.id,
|
||||||
'message_id': message.id,
|
'message_id': message.id,
|
||||||
'user': user,
|
'context': contextvars.copy_context(),
|
||||||
'context': contextvars.copy_context()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
@ -246,7 +245,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
message=message,
|
message=message,
|
||||||
user=user,
|
user=user,
|
||||||
stream=stream
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AdvancedChatAppGenerateResponseConverter.convert(
|
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||||
@ -259,7 +258,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user: Account,
|
|
||||||
context: contextvars.Context) -> None:
|
context: contextvars.Context) -> None:
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
@ -307,14 +305,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
finally:
|
finally:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
def _handle_advanced_chat_response(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
stream: bool = False) \
|
stream: bool = False,
|
||||||
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||||
"""
|
"""
|
||||||
Handle response.
|
Handle response.
|
||||||
:param application_generate_entity: application generate entity
|
:param application_generate_entity: application generate entity
|
||||||
@ -334,7 +335,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
message=message,
|
message=message,
|
||||||
user=user,
|
user=user,
|
||||||
stream=stream
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -3,9 +3,6 @@ import os
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||||
|
@ -44,7 +44,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
|||||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.enums import SystemVariable
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -76,7 +76,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
stream: bool
|
stream: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||||
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
SystemVariable.QUERY: message.query,
|
SystemVariable.QUERY: message.query,
|
||||||
SystemVariable.FILES: application_generate_entity.files,
|
SystemVariable.FILES: application_generate_entity.files,
|
||||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||||
SystemVariable.USER_ID: user_id
|
SystemVariable.USER_ID: user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
|
|||||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
|||||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||||
from models.model import App, AppMode, Message, MessageAnnotation
|
from models.model import App, AppMode, Message, MessageAnnotation
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
|
|
||||||
class AppRunner:
|
class AppRunner:
|
||||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
query: Optional[str] = None) -> int:
|
query: Optional[str] = None) -> int:
|
||||||
"""
|
"""
|
||||||
Get pre calculate rest tokens
|
Get pre calculate rest tokens
|
||||||
@ -126,7 +128,7 @@ class AppRunner:
|
|||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
memory: Optional[TokenBufferMemory] = None) \
|
memory: Optional[TokenBufferMemory] = None) \
|
||||||
|
@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
return introduction
|
return introduction
|
||||||
|
|
||||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
def _get_conversation(self, conversation_id: str):
|
||||||
"""
|
"""
|
||||||
Get conversation by conversation id
|
Get conversation by conversation id
|
||||||
:param conversation_id: conversation id
|
:param conversation_id: conversation id
|
||||||
@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
raise ConversationNotExistsError()
|
||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
def _get_message(self, message_id: str) -> Message:
|
def _get_message(self, message_id: str) -> Message:
|
||||||
|
@ -11,7 +11,8 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
WorkflowAppGenerateEntity,
|
WorkflowAppGenerateEntity,
|
||||||
)
|
)
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
from core.workflow.enums import SystemVariable
|
||||||
|
from core.workflow.entities.node_entities import UserFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -41,7 +41,9 @@ from core.app.entities.task_entities import (
|
|||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
from core.workflow.enums import SystemVariable
|
||||||
|
from core.workflow.nodes.end.end_node import EndNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
|
|||||||
from .segments import (
|
from .segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
ArraySegment,
|
ArraySegment,
|
||||||
FileSegment,
|
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
@ -13,11 +12,9 @@ from .segments import (
|
|||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
from .variables import (
|
from .variables import (
|
||||||
ArrayAnyVariable,
|
ArrayAnyVariable,
|
||||||
ArrayFileVariable,
|
|
||||||
ArrayNumberVariable,
|
ArrayNumberVariable,
|
||||||
ArrayObjectVariable,
|
ArrayObjectVariable,
|
||||||
ArrayStringVariable,
|
ArrayStringVariable,
|
||||||
FileVariable,
|
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
NoneVariable,
|
NoneVariable,
|
||||||
@ -32,7 +29,6 @@ __all__ = [
|
|||||||
'FloatVariable',
|
'FloatVariable',
|
||||||
'ObjectVariable',
|
'ObjectVariable',
|
||||||
'SecretVariable',
|
'SecretVariable',
|
||||||
'FileVariable',
|
|
||||||
'StringVariable',
|
'StringVariable',
|
||||||
'ArrayAnyVariable',
|
'ArrayAnyVariable',
|
||||||
'Variable',
|
'Variable',
|
||||||
@ -45,11 +41,9 @@ __all__ = [
|
|||||||
'FloatSegment',
|
'FloatSegment',
|
||||||
'ObjectSegment',
|
'ObjectSegment',
|
||||||
'ArrayAnySegment',
|
'ArrayAnySegment',
|
||||||
'FileSegment',
|
|
||||||
'StringSegment',
|
'StringSegment',
|
||||||
'ArrayStringVariable',
|
'ArrayStringVariable',
|
||||||
'ArrayNumberVariable',
|
'ArrayNumberVariable',
|
||||||
'ArrayObjectVariable',
|
'ArrayObjectVariable',
|
||||||
'ArrayFileVariable',
|
|
||||||
'ArraySegment',
|
'ArraySegment',
|
||||||
]
|
]
|
||||||
|
@ -2,12 +2,10 @@ from collections.abc import Mapping
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
|
|
||||||
from .exc import VariableError
|
from .exc import VariableError
|
||||||
from .segments import (
|
from .segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
FileSegment,
|
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
@ -17,11 +15,9 @@ from .segments import (
|
|||||||
)
|
)
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
from .variables import (
|
from .variables import (
|
||||||
ArrayFileVariable,
|
|
||||||
ArrayNumberVariable,
|
ArrayNumberVariable,
|
||||||
ArrayObjectVariable,
|
ArrayObjectVariable,
|
||||||
ArrayStringVariable,
|
ArrayStringVariable,
|
||||||
FileVariable,
|
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
ObjectVariable,
|
ObjectVariable,
|
||||||
@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
|||||||
result = FloatVariable.model_validate(mapping)
|
result = FloatVariable.model_validate(mapping)
|
||||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||||
raise VariableError(f'invalid number value {value}')
|
raise VariableError(f'invalid number value {value}')
|
||||||
case SegmentType.FILE:
|
|
||||||
result = FileVariable.model_validate(mapping)
|
|
||||||
case SegmentType.OBJECT if isinstance(value, dict):
|
case SegmentType.OBJECT if isinstance(value, dict):
|
||||||
result = ObjectVariable.model_validate(mapping)
|
result = ObjectVariable.model_validate(mapping)
|
||||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||||
@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
|||||||
result = ArrayNumberVariable.model_validate(mapping)
|
result = ArrayNumberVariable.model_validate(mapping)
|
||||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||||
result = ArrayObjectVariable.model_validate(mapping)
|
result = ArrayObjectVariable.model_validate(mapping)
|
||||||
case SegmentType.ARRAY_FILE if isinstance(value, list):
|
|
||||||
mapping = dict(mapping)
|
|
||||||
mapping['value'] = [{'value': v} for v in value]
|
|
||||||
result = ArrayFileVariable.model_validate(mapping)
|
|
||||||
case _:
|
case _:
|
||||||
raise VariableError(f'not supported value type {value_type}')
|
raise VariableError(f'not supported value type {value_type}')
|
||||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||||
@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
|
|||||||
return ObjectSegment(value=value)
|
return ObjectSegment(value=value)
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
return ArrayAnySegment(value=value)
|
return ArrayAnySegment(value=value)
|
||||||
if isinstance(value, FileVar):
|
|
||||||
return FileSegment(value=value)
|
|
||||||
raise ValueError(f'not supported value {value}')
|
raise ValueError(f'not supported value {value}')
|
||||||
|
@ -5,8 +5,6 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
|
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
@ -78,14 +76,7 @@ class IntegerSegment(Segment):
|
|||||||
value: int
|
value: int
|
||||||
|
|
||||||
|
|
||||||
class FileSegment(Segment):
|
|
||||||
value_type: SegmentType = SegmentType.FILE
|
|
||||||
# TODO: embed FileVar in this model.
|
|
||||||
value: FileVar
|
|
||||||
|
|
||||||
@property
|
|
||||||
def markdown(self) -> str:
|
|
||||||
return self.value.to_markdown()
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectSegment(Segment):
|
class ObjectSegment(Segment):
|
||||||
@ -108,7 +99,13 @@ class ObjectSegment(Segment):
|
|||||||
class ArraySegment(Segment):
|
class ArraySegment(Segment):
|
||||||
@property
|
@property
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
return '\n'.join(['- ' + item.markdown for item in self.value])
|
items = []
|
||||||
|
for item in self.value:
|
||||||
|
if hasattr(item, 'to_markdown'):
|
||||||
|
items.append(item.to_markdown())
|
||||||
|
else:
|
||||||
|
items.append(str(item))
|
||||||
|
return '\n'.join(items)
|
||||||
|
|
||||||
|
|
||||||
class ArrayAnySegment(ArraySegment):
|
class ArrayAnySegment(ArraySegment):
|
||||||
@ -130,7 +127,3 @@ class ArrayObjectSegment(ArraySegment):
|
|||||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||||
value: Sequence[Mapping[str, Any]]
|
value: Sequence[Mapping[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class ArrayFileSegment(ArraySegment):
|
|
||||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
|
||||||
value: Sequence[FileSegment]
|
|
||||||
|
@ -10,8 +10,6 @@ class SegmentType(str, Enum):
|
|||||||
ARRAY_STRING = 'array[string]'
|
ARRAY_STRING = 'array[string]'
|
||||||
ARRAY_NUMBER = 'array[number]'
|
ARRAY_NUMBER = 'array[number]'
|
||||||
ARRAY_OBJECT = 'array[object]'
|
ARRAY_OBJECT = 'array[object]'
|
||||||
ARRAY_FILE = 'array[file]'
|
|
||||||
OBJECT = 'object'
|
OBJECT = 'object'
|
||||||
FILE = 'file'
|
|
||||||
|
|
||||||
GROUP = 'group'
|
GROUP = 'group'
|
||||||
|
@ -4,11 +4,9 @@ from core.helper import encrypter
|
|||||||
|
|
||||||
from .segments import (
|
from .segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
ArrayFileSegment,
|
|
||||||
ArrayNumberSegment,
|
ArrayNumberSegment,
|
||||||
ArrayObjectSegment,
|
ArrayObjectSegment,
|
||||||
ArrayStringSegment,
|
ArrayStringSegment,
|
||||||
FileSegment,
|
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FileVariable(FileSegment, Variable):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectVariable(ObjectSegment, Variable):
|
class ObjectVariable(ObjectSegment, Variable):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ArrayFileVariable(ArrayFileSegment, Variable):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SecretVariable(StringVariable):
|
class SecretVariable(StringVariable):
|
||||||
value_type: SegmentType = SegmentType.SECRET
|
value_type: SegmentType = SegmentType.SECRET
|
||||||
|
@ -99,7 +99,7 @@ class MessageFileParser:
|
|||||||
# return all file objs
|
# return all file objs
|
||||||
return new_files
|
return new_files
|
||||||
|
|
||||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
|
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
|
||||||
"""
|
"""
|
||||||
transform message files
|
transform message files
|
||||||
|
|
||||||
@ -144,7 +144,7 @@ class MessageFileParser:
|
|||||||
|
|
||||||
return type_file_objs
|
return type_file_objs
|
||||||
|
|
||||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
|
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
|
||||||
"""
|
"""
|
||||||
transform file to file obj
|
transform file to file obj
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from core.model_runtime.entities.model_entities import DefaultParameterName
|
from core.model_runtime.entities.model_entities import DefaultParameterName
|
||||||
|
|
||||||
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||||
@ -94,5 +93,16 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
|||||||
},
|
},
|
||||||
'required': False,
|
'required': False,
|
||||||
'options': ['JSON', 'XML'],
|
'options': ['JSON', 'XML'],
|
||||||
}
|
},
|
||||||
|
DefaultParameterName.JSON_SCHEMA: {
|
||||||
|
'label': {
|
||||||
|
'en_US': 'JSON Schema',
|
||||||
|
},
|
||||||
|
'type': 'text',
|
||||||
|
'help': {
|
||||||
|
'en_US': 'Set a response json schema will ensure LLM to adhere it.',
|
||||||
|
'zh_Hans': '设置返回的json schema,llm将按照它返回',
|
||||||
|
},
|
||||||
|
'required': False,
|
||||||
|
},
|
||||||
}
|
}
|
@ -95,6 +95,7 @@ class DefaultParameterName(Enum):
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
+JSON_SCHEMA = "json_schema"

@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
@ -118,6 +119,7 @@ class ParameterType(Enum):
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
+TEXT = "text"


class ModelPropertyKey(Enum):
@ -2,6 +2,7 @@
- gpt-4o
- gpt-4o-2024-05-13
- gpt-4o-2024-08-06
+- chatgpt-4o-latest
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- gpt-4-turbo
@ -0,0 +1,44 @@
|
|||||||
|
model: chatgpt-4o-latest
|
||||||
|
label:
|
||||||
|
zh_Hans: chatgpt-4o-latest
|
||||||
|
en_US: chatgpt-4o-latest
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 128000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: presence_penalty
|
||||||
|
use_template: presence_penalty
|
||||||
|
- name: frequency_penalty
|
||||||
|
use_template: frequency_penalty
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 512
|
||||||
|
min: 1
|
||||||
|
max: 16384
|
||||||
|
- name: response_format
|
||||||
|
label:
|
||||||
|
zh_Hans: 回复格式
|
||||||
|
en_US: response_format
|
||||||
|
type: string
|
||||||
|
help:
|
||||||
|
zh_Hans: 指定模型必须输出的格式
|
||||||
|
en_US: specifying the format that the model must output
|
||||||
|
required: false
|
||||||
|
options:
|
||||||
|
- text
|
||||||
|
- json_object
|
||||||
|
pricing:
|
||||||
|
input: '2.50'
|
||||||
|
output: '10.00'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
+- json_schema
+- name: json_schema
+use_template: json_schema
pricing:
input: '2.50'
output: '10.00'
@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
+- json_schema
+- name: json_schema
+use_template: json_schema
pricing:
input: '0.15'
output: '0.60'
@ -1,3 +1,4 @@
+import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
@ -544,13 +545,18 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):

response_format = model_parameters.get("response_format")
if response_format:
-if response_format == "json_object":
-response_format = {"type": "json_object"}
+if response_format == "json_schema":
+json_schema = model_parameters.get("json_schema")
+if not json_schema:
+raise ValueError("Must define JSON Schema when the response format is json_schema")
+try:
+schema = json.loads(json_schema)
+except:
+raise ValueError(f"not correct json_schema format: {json_schema}")
+model_parameters.pop("json_schema")
+model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
-response_format = {"type": "text"}
+model_parameters["response_format"] = {"type": response_format}

-model_parameters["response_format"] = response_format

extra_model_kwargs = {}
|
|||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||||
|
|
||||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
|
||||||
if model.startswith('ft:'):
|
if model.startswith('ft:'):
|
||||||
model = model.split(':')[1]
|
model = model.split(':')[1]
|
||||||
|
|
||||||
|
# Currently, we can use gpt4o to calculate chatgpt-4o-latest's token.
|
||||||
|
if model == "chatgpt-4o-latest":
|
||||||
|
model = "gpt-4o"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.encoding_for_model(model)
|
encoding = tiktoken.encoding_for_model(model)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -946,7 +955,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"get_num_tokens_from_messages() is not presently implemented "
|
f"get_num_tokens_from_messages() is not presently implemented "
|
||||||
f"for model {model}."
|
f"for model {model}."
|
||||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
"See https://platform.openai.com/docs/advanced-usage/managing-tokens for "
|
||||||
"information on how messages are converted to tokens."
|
"information on how messages are converted to tokens."
|
||||||
)
|
)
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
|
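The token-counting hunk above maps chatgpt-4o-latest onto gpt-4o before asking tiktoken for an encoding. A small sketch of that alias-plus-fallback pattern; the cl100k_base fallback is an assumption for illustration and not quoted from the commit:

import tiktoken

def encoding_for(model: str) -> tiktoken.Encoding:
    # chatgpt-4o-latest is not a model name tiktoken knows, so count its tokens
    # with gpt-4o's encoding, as the diff above does.
    if model == "chatgpt-4o-latest":
        model = "gpt-4o"
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        # Assumption: fall back to a common encoding when the model is unknown.
        return tiktoken.get_encoding("cl100k_base")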
@ -0,0 +1,61 @@
|
|||||||
|
model: Llama3-Chinese_v2
|
||||||
|
label:
|
||||||
|
en_US: Llama3-Chinese_v2
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 8192
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.5
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
@ -0,0 +1,61 @@
|
|||||||
|
model: Meta-Llama-3-70B-Instruct-GPTQ-Int4
|
||||||
|
label:
|
||||||
|
en_US: Meta-Llama-3-70B-Instruct-GPTQ-Int4
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 1024
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.5
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
@ -0,0 +1,61 @@
|
|||||||
|
model: Meta-Llama-3-8B-Instruct
|
||||||
|
label:
|
||||||
|
en_US: Meta-Llama-3-8B-Instruct
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 8192
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.5
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
@ -0,0 +1,61 @@
|
|||||||
|
model: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
|
||||||
|
label:
|
||||||
|
en_US: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 410960
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.5
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
@ -0,0 +1,61 @@
|
|||||||
|
model: Meta-Llama-3.1-8B-Instruct
|
||||||
|
label:
|
||||||
|
en_US: Meta-Llama-3.1-8B-Instruct
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 4096
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.1
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
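
The parameter_rules above expose the familiar sampling knobs (temperature, top_p, top_k, max_tokens, repetition_penalty) for each newly added model. As a rough, non-authoritative sketch, the defaults could be assembled into an OpenAI-compatible chat request like the one below; the endpoint, API key, and model choice are invented for illustration and are not part of this change.

import requests

# Hedged sketch: maps the parameter_rules defaults above onto a chat request.
# The endpoint and API key are placeholders, not values from this diff.
payload = {
    "model": "Meta-Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.1,         # range 0.0 - 2.0 per the rules above
    "top_p": 0.8,               # range 0.1 - 0.9
    "top_k": 50,                # range 0 - 99
    "max_tokens": 600,          # range 1 - 1248
    "repetition_penalty": 1.1,  # 1.0 means no penalty
}
response = requests.post(
    "https://api.example-provider.com/v1/chat/completions",  # hypothetical URL
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    json=payload,
    timeout=30,
)
print(response.json())
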
@ -55,7 +55,8 @@ parameter_rules:
|
|||||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.000'
|
input: "0.000"
|
||||||
output: '0.000'
|
output: "0.000"
|
||||||
unit: '0.000'
|
unit: "0.000"
|
||||||
currency: RMB
|
currency: RMB
|
||||||
|
deprecated: true
|
||||||
|
@ -55,7 +55,8 @@ parameter_rules:
|
|||||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.000'
|
input: "0.000"
|
||||||
output: '0.000'
|
output: "0.000"
|
||||||
unit: '0.000'
|
unit: "0.000"
|
||||||
currency: RMB
|
currency: RMB
|
||||||
|
deprecated: true
|
||||||
|
@ -6,7 +6,7 @@ features:
|
|||||||
- agent-thought
|
- agent-thought
|
||||||
model_properties:
|
model_properties:
|
||||||
mode: chat
|
mode: chat
|
||||||
context_size: 8192
|
context_size: 2048
|
||||||
parameter_rules:
|
parameter_rules:
|
||||||
- name: temperature
|
- name: temperature
|
||||||
use_template: temperature
|
use_template: temperature
|
||||||
@ -55,7 +55,7 @@ parameter_rules:
|
|||||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.000'
|
input: "0.000"
|
||||||
output: '0.000'
|
output: "0.000"
|
||||||
unit: '0.000'
|
unit: "0.000"
|
||||||
currency: RMB
|
currency: RMB
|
||||||
|
@ -6,7 +6,7 @@ features:
|
|||||||
- agent-thought
|
- agent-thought
|
||||||
model_properties:
|
model_properties:
|
||||||
mode: completion
|
mode: completion
|
||||||
context_size: 8192
|
context_size: 32768
|
||||||
parameter_rules:
|
parameter_rules:
|
||||||
- name: temperature
|
- name: temperature
|
||||||
use_template: temperature
|
use_template: temperature
|
||||||
@ -55,7 +55,7 @@ parameter_rules:
|
|||||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.000'
|
input: "0.000"
|
||||||
output: '0.000'
|
output: "0.000"
|
||||||
unit: '0.000'
|
unit: "0.000"
|
||||||
currency: RMB
|
currency: RMB
|
||||||
|
@ -8,12 +8,12 @@ features:
|
|||||||
- stream-tool-call
|
- stream-tool-call
|
||||||
model_properties:
|
model_properties:
|
||||||
mode: chat
|
mode: chat
|
||||||
context_size: 8192
|
context_size: 2048
|
||||||
parameter_rules:
|
parameter_rules:
|
||||||
- name: temperature
|
- name: temperature
|
||||||
use_template: temperature
|
use_template: temperature
|
||||||
type: float
|
type: float
|
||||||
default: 0.3
|
default: 0.7
|
||||||
min: 0.0
|
min: 0.0
|
||||||
max: 2.0
|
max: 2.0
|
||||||
help:
|
help:
|
||||||
@ -57,7 +57,7 @@ parameter_rules:
|
|||||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.000'
|
input: "0.000"
|
||||||
output: '0.000'
|
output: "0.000"
|
||||||
unit: '0.000'
|
unit: "0.000"
|
||||||
currency: RMB
|
currency: RMB
|
||||||
|
@ -0,0 +1,61 @@
|
|||||||
|
model: Qwen2-72B-Instruct
|
||||||
|
label:
|
||||||
|
en_US: Qwen2-72B-Instruct
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 131072
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.5
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
@ -8,7 +8,7 @@ features:
|
|||||||
- stream-tool-call
|
- stream-tool-call
|
||||||
model_properties:
|
model_properties:
|
||||||
mode: completion
|
mode: completion
|
||||||
context_size: 8192
|
context_size: 32768
|
||||||
parameter_rules:
|
parameter_rules:
|
||||||
- name: temperature
|
- name: temperature
|
||||||
use_template: temperature
|
use_template: temperature
|
||||||
@ -57,7 +57,7 @@ parameter_rules:
|
|||||||
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.000'
|
input: "0.000"
|
||||||
output: '0.000'
|
output: "0.000"
|
||||||
unit: '0.000'
|
unit: "0.000"
|
||||||
currency: RMB
|
currency: RMB
|
||||||
|
@ -1,6 +1,15 @@
|
|||||||
|
- Meta-Llama-3.1-405B-Instruct-AWQ-INT4
|
||||||
|
- Meta-Llama-3.1-8B-Instruct
|
||||||
|
- Meta-Llama-3-70B-Instruct-GPTQ-Int4
|
||||||
|
- Meta-Llama-3-8B-Instruct
|
||||||
- Qwen2-72B-Instruct-GPTQ-Int4
|
- Qwen2-72B-Instruct-GPTQ-Int4
|
||||||
|
- Qwen2-72B-Instruct
|
||||||
- Qwen2-7B
|
- Qwen2-7B
|
||||||
- Qwen1.5-110B-Chat-GPTQ-Int4
|
- Qwen-14B-Chat-Int4
|
||||||
- Qwen1.5-72B-Chat-GPTQ-Int4
|
- Qwen1.5-72B-Chat-GPTQ-Int4
|
||||||
- Qwen1.5-7B
|
- Qwen1.5-7B
|
||||||
- Qwen-14B-Chat-Int4
|
- Qwen1.5-110B-Chat-GPTQ-Int4
|
||||||
|
- deepseek-v2-chat
|
||||||
|
- deepseek-v2-lite-chat
|
||||||
|
- Llama3-Chinese_v2
|
||||||
|
- chatglm3-6b
|
||||||
|
@ -0,0 +1,61 @@
|
|||||||
|
model: chatglm3-6b
|
||||||
|
label:
|
||||||
|
en_US: chatglm3-6b
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 8192
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.5
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
@ -0,0 +1,61 @@
|
|||||||
|
model: deepseek-v2-chat
|
||||||
|
label:
|
||||||
|
en_US: deepseek-v2-chat
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 4096
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.5
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
@ -0,0 +1,61 @@
|
|||||||
|
model: deepseek-v2-lite-chat
|
||||||
|
label:
|
||||||
|
en_US: deepseek-v2-lite-chat
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 2048
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.5
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 600
|
||||||
|
min: 1
|
||||||
|
max: 1248
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
pricing:
|
||||||
|
input: "0.000"
|
||||||
|
output: "0.000"
|
||||||
|
unit: "0.000"
|
||||||
|
currency: RMB
|
@ -0,0 +1,4 @@
|
|||||||
|
model: BAAI/bge-large-en-v1.5
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 32768
|
@ -0,0 +1,4 @@
|
|||||||
|
model: BAAI/bge-large-zh-v1.5
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 32768
|
@ -0,0 +1,81 @@
|
|||||||
|
model: farui-plus
|
||||||
|
label:
|
||||||
|
en_US: farui-plus
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 12288
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
type: float
|
||||||
|
default: 0.3
|
||||||
|
min: 0.0
|
||||||
|
max: 2.0
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
|
||||||
|
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
type: int
|
||||||
|
default: 2000
|
||||||
|
min: 1
|
||||||
|
max: 2000
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
|
||||||
|
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
type: float
|
||||||
|
default: 0.8
|
||||||
|
min: 0.1
|
||||||
|
max: 0.9
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
|
||||||
|
en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
|
||||||
|
- name: top_k
|
||||||
|
type: int
|
||||||
|
min: 0
|
||||||
|
max: 99
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
|
||||||
|
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
|
||||||
|
- name: seed
|
||||||
|
required: false
|
||||||
|
type: int
|
||||||
|
default: 1234
|
||||||
|
label:
|
||||||
|
zh_Hans: 随机种子
|
||||||
|
en_US: Random seed
|
||||||
|
help:
|
||||||
|
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||||
|
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||||
|
- name: repetition_penalty
|
||||||
|
required: false
|
||||||
|
type: float
|
||||||
|
default: 1.1
|
||||||
|
label:
|
||||||
|
en_US: Repetition penalty
|
||||||
|
help:
|
||||||
|
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
|
||||||
|
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
|
||||||
|
- name: enable_search
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
help:
|
||||||
|
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||||
|
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
|
pricing:
|
||||||
|
input: '0.02'
|
||||||
|
output: '0.02'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: RMB
|
@ -159,6 +159,8 @@ You should also complete the text started with ``` but not tell ``` directly.
|
|||||||
"""
|
"""
|
||||||
if model in ['qwen-turbo-chat', 'qwen-plus-chat']:
|
if model in ['qwen-turbo-chat', 'qwen-plus-chat']:
|
||||||
model = model.replace('-chat', '')
|
model = model.replace('-chat', '')
|
||||||
|
if model == 'farui-plus':
|
||||||
|
model = 'qwen-farui-plus'
|
||||||
|
|
||||||
if model in self.tokenizers:
|
if model in self.tokenizers:
|
||||||
tokenizer = self.tokenizers[model]
|
tokenizer = self.tokenizers[model]
|
||||||
|
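
The hunk above grows the tokenizer alias table: chat-suffixed Qwen names are normalized, and farui-plus is mapped to the qwen-farui-plus tokenizer before the lookup. A compact restatement of that normalization, with a stand-in tokenizer registry, might look like this:

# Compact restatement of the alias normalization above; the registry is a stand-in.
def resolve_tokenizer_key(model: str) -> str:
    if model in ['qwen-turbo-chat', 'qwen-plus-chat']:
        model = model.replace('-chat', '')
    if model == 'farui-plus':
        model = 'qwen-farui-plus'
    return model

tokenizers = {'qwen-turbo': ..., 'qwen-plus': ..., 'qwen-farui-plus': ...}
assert resolve_tokenizer_key('farui-plus') in tokenizers
assert resolve_tokenizer_key('qwen-turbo-chat') == 'qwen-turbo'
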
@ -1 +1 @@
|
|||||||
- soloar-1-mini-chat
|
- solar-1-mini-chat
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import PromptTemplateEntity
|
from core.app.app_config.entities import PromptTemplateEntity
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform
|
|||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
|
|
||||||
class ModelMode(enum.Enum):
|
class ModelMode(enum.Enum):
|
||||||
COMPLETION = 'completion'
|
COMPLETION = 'completion'
|
||||||
@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) -> \
|
model_config: ModelConfigWithCredentialsEntity) -> \
|
||||||
@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) \
|
model_config: ModelConfigWithCredentialsEntity) \
|
||||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) \
|
model_config: ModelConfigWithCredentialsEntity) \
|
||||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
|
|
||||||
return [self.get_last_user_message(prompt, files)], stops
|
return [self.get_last_user_message(prompt, files)], stops
|
||||||
|
|
||||||
def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
|
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
||||||
for file in files:
|
for file in files:
|
||||||
|
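
The simple_prompt_transform changes above defer the FileVar import behind TYPE_CHECKING and quote the annotations, which removes a runtime import cycle while preserving static typing. A minimal, self-contained sketch of the pattern (the helper function is an example of the idea, not code from this change):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed, so it cannot create a
    # circular import between the two modules at runtime.
    from core.file.file_obj import FileVar

def describe_files(files: list["FileVar"]) -> str:
    # The quoted annotation is resolved lazily, so FileVar does not have to
    # be importable when this module is loaded.
    return f"{len(files)} file(s) attached"
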
@ -0,0 +1,191 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
from flask import current_app
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
from core.rag.datasource.entity.embedding import Embeddings
|
||||||
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||||
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
|
from core.rag.models.document import Document
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticSearchConfig(BaseModel):
|
||||||
|
host: str
|
||||||
|
port: str
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
@model_validator(mode='before')
|
||||||
|
def validate_config(cls, values: dict) -> dict:
|
||||||
|
if not values['host']:
|
||||||
|
raise ValueError("config HOST is required")
|
||||||
|
if not values['port']:
|
||||||
|
raise ValueError("config PORT is required")
|
||||||
|
if not values['username']:
|
||||||
|
raise ValueError("config USERNAME is required")
|
||||||
|
if not values['password']:
|
||||||
|
raise ValueError("config PASSWORD is required")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticSearchVector(BaseVector):
|
||||||
|
def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
|
||||||
|
super().__init__(index_name.lower())
|
||||||
|
self._client = self._init_client(config)
|
||||||
|
self._attributes = attributes
|
||||||
|
|
||||||
|
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
||||||
|
try:
|
||||||
|
client = Elasticsearch(
|
||||||
|
hosts=f'{config.host}:{config.port}',
|
||||||
|
basic_auth=(config.username, config.password),
|
||||||
|
request_timeout=100000,
|
||||||
|
retry_on_timeout=True,
|
||||||
|
max_retries=10000,
|
||||||
|
)
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
raise ConnectionError("Vector database connection error")
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return 'elasticsearch'
|
||||||
|
|
||||||
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
uuids = self._get_uuids(documents)
|
||||||
|
texts = [d.page_content for d in documents]
|
||||||
|
metadatas = [d.metadata for d in documents]
|
||||||
|
|
||||||
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
|
dim = len(embeddings[0])
|
||||||
|
mapping = {
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "text"
|
||||||
|
},
|
||||||
|
"vector": {
|
||||||
|
"type": "dense_vector",
|
||||||
|
"index": True,
|
||||||
|
"dims": dim,
|
||||||
|
"similarity": "l2_norm"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self._client.indices.create(index=self._collection_name, mappings=mapping)
|
||||||
|
|
||||||
|
added_ids = []
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
self._client.index(index=self._collection_name,
|
||||||
|
id=uuids[i],
|
||||||
|
document={
|
||||||
|
"text": text,
|
||||||
|
"vector": embeddings[i] if embeddings[i] else None,
|
||||||
|
"metadata": metadatas[i] if metadatas[i] else {},
|
||||||
|
})
|
||||||
|
added_ids.append(uuids[i])
|
||||||
|
|
||||||
|
self._client.indices.refresh(index=self._collection_name)
|
||||||
|
return uuids
|
||||||
|
|
||||||
|
def text_exists(self, id: str) -> bool:
|
||||||
|
return self._client.exists(index=self._collection_name, id=id).__bool__()
|
||||||
|
|
||||||
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
for id in ids:
|
||||||
|
self._client.delete(index=self._collection_name, id=id)
|
||||||
|
|
||||||
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||||
|
query_str = {
|
||||||
|
'query': {
|
||||||
|
'match': {
|
||||||
|
f'metadata.{key}': f'{value}'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results = self._client.search(index=self._collection_name, body=query_str)
|
||||||
|
ids = [hit['_id'] for hit in results['hits']['hits']]
|
||||||
|
if ids:
|
||||||
|
self.delete_by_ids(ids)
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
self._client.indices.delete(index=self._collection_name)
|
||||||
|
|
||||||
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
|
query_str = {
|
||||||
|
"query": {
|
||||||
|
"script_score": {
|
||||||
|
"query": {
|
||||||
|
"match_all": {}
|
||||||
|
},
|
||||||
|
"script": {
|
||||||
|
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
|
||||||
|
"params": {
|
||||||
|
"query_vector": query_vector
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results = self._client.search(index=self._collection_name, body=query_str)
|
||||||
|
|
||||||
|
docs_and_scores = []
|
||||||
|
for hit in results['hits']['hits']:
|
||||||
|
docs_and_scores.append(
|
||||||
|
(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
for doc, score in docs_and_scores:
|
||||||
|
score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
|
||||||
|
if score > score_threshold:
|
||||||
|
doc.metadata['score'] = score
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
# Sort the documents by score in descending order
|
||||||
|
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
|
||||||
|
|
||||||
|
return docs
|
||||||
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
|
query_str = {
|
||||||
|
"match": {
|
||||||
|
"text": query
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results = self._client.search(index=self._collection_name, query=query_str)
|
||||||
|
docs = []
|
||||||
|
for hit in results['hits']['hits']:
|
||||||
|
docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
|
return self.add_texts(texts, embeddings, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticSearchVectorFactory(AbstractVectorFactory):
|
||||||
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
|
||||||
|
if dataset.index_struct_dict:
|
||||||
|
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||||
|
collection_name = class_prefix
|
||||||
|
else:
|
||||||
|
dataset_id = dataset.id
|
||||||
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
|
dataset.index_struct = json.dumps(
|
||||||
|
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||||
|
|
||||||
|
config = current_app.config
|
||||||
|
return ElasticSearchVector(
|
||||||
|
index_name=collection_name,
|
||||||
|
config=ElasticSearchConfig(
|
||||||
|
host=config.get('ELASTICSEARCH_HOST'),
|
||||||
|
port=config.get('ELASTICSEARCH_PORT'),
|
||||||
|
username=config.get('ELASTICSEARCH_USERNAME'),
|
||||||
|
password=config.get('ELASTICSEARCH_PASSWORD'),
|
||||||
|
),
|
||||||
|
attributes=[]
|
||||||
|
)
|
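
A rough usage sketch for the new ElasticSearchVector follows. The connection details, embedding vectors, and metadata keys are invented for illustration (the base class is assumed to derive document ids from metadata), so treat this as a sketch of the public surface rather than a tested example.

# Hedged usage sketch for the class added above; values are placeholders.
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
    ElasticSearchConfig,
    ElasticSearchVector,
)
from core.rag.models.document import Document

vector = ElasticSearchVector(
    index_name="Vector_index_demo",  # lower-cased by the constructor
    config=ElasticSearchConfig(host="http://localhost", port="9200",
                               username="elastic", password="changeme"),
    attributes=[],
)

docs = [Document(page_content="hello dify", metadata={"doc_id": "doc-1"})]  # metadata key assumed
embeddings = [[0.1, 0.2, 0.3]]

vector.create(docs, embeddings)                 # first write creates the index mapping
hits = vector.search_by_vector([0.1, 0.2, 0.3], score_threshold=0.0)
print([d.page_content for d in hits])
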
@ -93,7 +93,7 @@ class MyScaleVector(BaseVector):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def escape_str(value: Any) -> str:
|
def escape_str(value: Any) -> str:
|
||||||
return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))
|
return "".join(" " if c in ("\\", "'") else c for c in str(value))
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
||||||
@ -118,7 +118,7 @@ class MyScaleVector(BaseVector):
|
|||||||
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
|
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
|
||||||
|
|
||||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)
|
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
|
||||||
|
|
||||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||||
top_k = kwargs.get("top_k", 5)
|
top_k = kwargs.get("top_k", 5)
|
||||||
|
@ -71,6 +71,9 @@ class Vector:
|
|||||||
case VectorType.RELYT:
|
case VectorType.RELYT:
|
||||||
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
|
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
|
||||||
return RelytVectorFactory
|
return RelytVectorFactory
|
||||||
|
case VectorType.ELASTICSEARCH:
|
||||||
|
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
|
return ElasticSearchVectorFactory
|
||||||
case VectorType.TIDB_VECTOR:
|
case VectorType.TIDB_VECTOR:
|
||||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
||||||
return TiDBVectorFactory
|
return TiDBVectorFactory
|
||||||
|
@ -15,3 +15,4 @@ class VectorType(str, Enum):
|
|||||||
OPENSEARCH = 'opensearch'
|
OPENSEARCH = 'opensearch'
|
||||||
TENCENT = 'tencent'
|
TENCENT = 'tencent'
|
||||||
ORACLE = 'oracle'
|
ORACLE = 'oracle'
|
||||||
|
ELASTICSEARCH = 'elasticsearch'
|
||||||
|
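
Registering the backend therefore touches two spots: the new ELASTICSEARCH member of VectorType and the matching case in the vector factory. A condensed illustration of that dispatch pattern (the surrounding Vector class is simplified away):

# Condensed illustration of the factory dispatch extended above.
from core.rag.datasource.vdb.vector_type import VectorType

def get_vector_factory(vector_type: VectorType):
    match vector_type:
        case VectorType.ELASTICSEARCH:
            from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
                ElasticSearchVectorFactory,
            )
            return ElasticSearchVectorFactory
        case _:
            raise ValueError(f"Vector store {vector_type} is not supported.")
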
@ -103,8 +103,8 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
plain text, image url or link url
|
plain text, image url or link url
|
||||||
"""
|
"""
|
||||||
message: Union[str, bytes, dict] = None
|
message: str | bytes | dict | None = None
|
||||||
meta: dict[str, Any] = None
|
meta: dict[str, Any] | None = None
|
||||||
save_as: str = ''
|
save_as: str = ''
|
||||||
|
|
||||||
class ToolInvokeMessageBinary(BaseModel):
|
class ToolInvokeMessageBinary(BaseModel):
|
||||||
|
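
The annotation change above matters under Pydantic v2: with message typed as Union[str, bytes, dict] but defaulting to None, explicitly passing None fails validation because None is not part of the union, so the new annotations fold None into the type. A tiny stand-in model showing the accepted shapes (field names copied from the diff, everything else simplified):

from typing import Any

from pydantic import BaseModel

class MessageSketch(BaseModel):
    # Stand-in for the real ToolInvokeMessage fields changed above.
    message: str | bytes | dict | None = None
    meta: dict[str, Any] | None = None
    save_as: str = ''

print(MessageSketch())                      # message=None meta=None save_as=''
print(MessageSketch(message='plain text'))  # str, bytes, dict and None all validate
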
@ -0,0 +1,2 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||||
|
<svg width="24" height="25" viewBox="0 0 24 25" xmlns="http://www.w3.org/2000/svg" fill="none"><path fill="#FC6D26" d="M14.975 8.904L14.19 6.55l-1.552-4.67a.268.268 0 00-.255-.18.268.268 0 00-.254.18l-1.552 4.667H5.422L3.87 1.879a.267.267 0 00-.254-.179.267.267 0 00-.254.18l-1.55 4.667-.784 2.357a.515.515 0 00.193.583l6.78 4.812 6.778-4.812a.516.516 0 00.196-.583z"/><path fill="#E24329" d="M8 14.296l2.578-7.75H5.423L8 14.296z"/><path fill="#FC6D26" d="M8 14.296l-2.579-7.75H1.813L8 14.296z"/><path fill="#FCA326" d="M1.81 6.549l-.784 2.354a.515.515 0 00.193.583L8 14.3 1.81 6.55z"/><path fill="#E24329" d="M1.812 6.549h3.612L3.87 1.882a.268.268 0 00-.254-.18.268.268 0 00-.255.18L1.812 6.549z"/><path fill="#FC6D26" d="M8 14.296l2.578-7.75h3.614L8 14.296z"/><path fill="#FCA326" d="M14.19 6.549l.783 2.354a.514.514 0 01-.193.583L8 14.296l6.188-7.747h.001z"/><path fill="#E24329" d="M14.19 6.549H10.58l1.551-4.667a.267.267 0 01.255-.18c.115 0 .217.073.254.18l1.552 4.667z"/></svg>
|
After Width: | Height: | Size: 1.1 KiB |
34
api/core/tools/provider/builtin/gitlab/gitlab.py
Normal file
34
api/core/tools/provider/builtin/gitlab/gitlab.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
|
||||||
|
|
||||||
|
class GitlabProvider(BuiltinToolProviderController):
|
||||||
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
|
try:
|
||||||
|
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
|
||||||
|
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.")
|
||||||
|
|
||||||
|
if 'site_url' not in credentials or not credentials.get('site_url'):
|
||||||
|
site_url = 'https://gitlab.com'
|
||||||
|
else:
|
||||||
|
site_url = credentials.get('site_url')
|
||||||
|
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/vnd.text+json",
|
||||||
|
"Authorization": f"Bearer {credentials.get('access_tokens')}",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
url= f"{site_url}/api/v4/user",
|
||||||
|
headers=headers)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise ToolProviderCredentialValidationError((response.json()).get('message'))
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e))
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError(str(e))
|
38
api/core/tools/provider/builtin/gitlab/gitlab.yaml
Normal file
38
api/core/tools/provider/builtin/gitlab/gitlab.yaml
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
identity:
|
||||||
|
author: Leo.Wang
|
||||||
|
name: gitlab
|
||||||
|
label:
|
||||||
|
en_US: Gitlab
|
||||||
|
zh_Hans: Gitlab
|
||||||
|
description:
|
||||||
|
en_US: Gitlab plugin for commit
|
||||||
|
zh_Hans: 用于获取Gitlab commit的插件
|
||||||
|
icon: gitlab.svg
|
||||||
|
credentials_for_provider:
|
||||||
|
access_tokens:
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: Gitlab access token
|
||||||
|
zh_Hans: Gitlab access token
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your Gitlab access token
|
||||||
|
zh_Hans: 请输入你的 Gitlab access token
|
||||||
|
help:
|
||||||
|
en_US: Get your Gitlab access token from Gitlab
|
||||||
|
zh_Hans: 从 Gitlab 获取您的 access token
|
||||||
|
url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
|
||||||
|
site_url:
|
||||||
|
type: text-input
|
||||||
|
required: false
|
||||||
|
default: 'https://gitlab.com'
|
||||||
|
label:
|
||||||
|
en_US: Gitlab site url
|
||||||
|
zh_Hans: Gitlab site url
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your Gitlab site url
|
||||||
|
zh_Hans: 请输入你的 Gitlab site url
|
||||||
|
help:
|
||||||
|
en_US: Find your Gitlab url
|
||||||
|
zh_Hans: 找到你的Gitlab url
|
||||||
|
url: https://gitlab.com/help
|
101
api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py
Normal file
101
api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
|
class GitlabCommitsTool(BuiltinTool):
|
||||||
|
def _invoke(self,
|
||||||
|
user_id: str,
|
||||||
|
tool_parameters: dict[str, Any]
|
||||||
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
|
|
||||||
|
project = tool_parameters.get('project', '')
|
||||||
|
employee = tool_parameters.get('employee', '')
|
||||||
|
start_time = tool_parameters.get('start_time', '')
|
||||||
|
end_time = tool_parameters.get('end_time', '')
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
return self.create_text_message('Project is required')
|
||||||
|
|
||||||
|
if not start_time:
|
||||||
|
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
|
||||||
|
if not end_time:
|
||||||
|
end_time = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
access_token = self.runtime.credentials.get('access_tokens')
|
||||||
|
site_url = self.runtime.credentials.get('site_url')
|
||||||
|
|
||||||
|
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
|
||||||
|
return self.create_text_message("Gitlab API Access Tokens is required.")
|
||||||
|
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
|
||||||
|
site_url = 'https://gitlab.com'
|
||||||
|
|
||||||
|
# Get commit content
|
||||||
|
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time)
|
||||||
|
|
||||||
|
return self.create_text_message(json.dumps(result, ensure_ascii=False))
|
||||||
|
|
||||||
|
def fetch(self, user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]:
|
||||||
|
domain = site_url
|
||||||
|
headers = {"PRIVATE-TOKEN": access_token}
|
||||||
|
results = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get all projects
|
||||||
|
url = f"{domain}/api/v4/projects"
|
||||||
|
response = requests.get(url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
projects = response.json()
|
||||||
|
|
||||||
|
filtered_projects = [p for p in projects if project == "*" or p['name'] == project]
|
||||||
|
|
||||||
|
for project in filtered_projects:
|
||||||
|
project_id = project['id']
|
||||||
|
project_name = project['name']
|
||||||
|
print(f"Project: {project_name}")
|
||||||
|
|
||||||
|
# Get all commits of the project
|
||||||
|
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
|
||||||
|
params = {
|
||||||
|
'since': start_time,
|
||||||
|
'until': end_time
|
||||||
|
}
|
||||||
|
if employee:
|
||||||
|
params['author'] = employee
|
||||||
|
|
||||||
|
commits_response = requests.get(commits_url, headers=headers, params=params)
|
||||||
|
commits_response.raise_for_status()
|
||||||
|
commits = commits_response.json()
|
||||||
|
|
||||||
|
for commit in commits:
|
||||||
|
commit_sha = commit['id']
|
||||||
|
print(f"\tCommit SHA: {commit_sha}")
|
||||||
|
|
||||||
|
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
|
||||||
|
diff_response = requests.get(diff_url, headers=headers)
|
||||||
|
diff_response.raise_for_status()
|
||||||
|
diffs = diff_response.json()
|
||||||
|
|
||||||
|
for diff in diffs:
|
||||||
|
# Calculate the number of changed lines of code
|
||||||
|
added_lines = diff['diff'].count('\n+')
|
||||||
|
removed_lines = diff['diff'].count('\n-')
|
||||||
|
total_changes = added_lines + removed_lines
|
||||||
|
|
||||||
|
if total_changes > 1:
|
||||||
|
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
|
||||||
|
results.append({
|
||||||
|
"project": project_name,
|
||||||
|
"commit_sha": commit_sha,
|
||||||
|
"diff": final_code
|
||||||
|
})
|
||||||
|
print(f"Commit code:{final_code}")
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"Error fetching data from GitLab: {e}")
|
||||||
|
|
||||||
|
return results
|
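
The core of fetch() is the change-counting heuristic: added and removed lines are counted from the raw diff text, commits with more than one changed line are kept, and the reported code is the concatenation of added lines. A standalone restatement of that heuristic, with an invented sample diff, reads:

# Standalone restatement of the heuristic used in fetch(); the sample diff is invented.
def summarize_diff(diff_text: str) -> dict:
    added_lines = diff_text.count('\n+')
    removed_lines = diff_text.count('\n-')
    final_code = ''.join(
        line[1:]
        for line in diff_text.split('\n')
        if line.startswith('+') and not line.startswith('+++')
    )
    return {
        "added": added_lines,
        "removed": removed_lines,
        "code": final_code,
        "significant": (added_lines + removed_lines) > 1,
    }

sample = "--- a/app.py\n+++ b/app.py\n+print('hi')\n+print('bye')\n-print('old')\n"
print(summarize_diff(sample))
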
@ -0,0 +1,56 @@
|
|||||||
|
identity:
|
||||||
|
name: gitlab_commits
|
||||||
|
author: Leo.Wang
|
||||||
|
label:
|
||||||
|
en_US: Gitlab Commits
|
||||||
|
zh_Hans: Gitlab代码提交内容
|
||||||
|
description:
|
||||||
|
human:
|
||||||
|
en_US: A tool for querying GitLab commits. Input should be an existing username.
|
||||||
|
zh_Hans: 一个用于查询gitlab代码提交记录的的工具,输入的内容应该是一个已存在的用户名或者项目名。
|
||||||
|
llm: A tool for querying GitLab commits. Input should be an existing username or project.
|
||||||
|
parameters:
|
||||||
|
- name: employee
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: employee
|
||||||
|
zh_Hans: 员工用户名
|
||||||
|
human_description:
|
||||||
|
en_US: employee
|
||||||
|
zh_Hans: 员工用户名
|
||||||
|
llm_description: employee for gitlab
|
||||||
|
form: llm
|
||||||
|
- name: project
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: project
|
||||||
|
zh_Hans: 项目名
|
||||||
|
human_description:
|
||||||
|
en_US: project
|
||||||
|
zh_Hans: 项目名
|
||||||
|
llm_description: project for gitlab
|
||||||
|
form: llm
|
||||||
|
- name: start_time
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: start_time
|
||||||
|
zh_Hans: 开始时间
|
||||||
|
human_description:
|
||||||
|
en_US: start_time
|
||||||
|
zh_Hans: 开始时间
|
||||||
|
llm_description: start_time for gitlab
|
||||||
|
form: llm
|
||||||
|
- name: end_time
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: end_time
|
||||||
|
zh_Hans: 结束时间
|
||||||
|
human_description:
|
||||||
|
en_US: end_time
|
||||||
|
zh_Hans: 结束时间
|
||||||
|
llm_description: end_time for gitlab
|
||||||
|
form: llm
|
@@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, field_validator
 from pydantic_core.core_schema import ValidationInfo
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.file.file_obj import FileVar
 from core.tools.entities.tool_entities import (
     ToolDescription,
     ToolIdentity,
@@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
 from core.tools.tool_file_manager import ToolFileManager
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 
+if TYPE_CHECKING:
+    from core.file.file_obj import FileVar
+
 
 class Tool(BaseModel, ABC):
     identity: Optional[ToolIdentity] = None
@@ -290,7 +292,7 @@ class Tool(BaseModel, ABC):
                                  message=image,
                                  save_as=save_as)
 
-    def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
+    def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
         return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
                                  message='',
                                  meta={
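The pattern introduced here, importing FileVar under `if TYPE_CHECKING:` and quoting the annotation as "FileVar", keeps the hint for static analysis while avoiding the runtime import, which is the usual way to break an import cycle. A self-contained sketch of the same idiom, using decimal.Decimal as a stand-in for FileVar:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; never executed at runtime,
    # so it cannot contribute to an import cycle.
    from decimal import Decimal

def as_text(value: "Decimal") -> str:
    # The quoted annotation is resolved lazily, so this runs even though
    # Decimal was never imported at runtime.
    return str(value)

print(as_text(1.5))  # runtime does not enforce the annotation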
@@ -1,7 +1,7 @@
 import logging
 from mimetypes import guess_extension
 
-from core.file.file_obj import FileTransferMethod, FileType, FileVar
+from core.file.file_obj import FileTransferMethod, FileType
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool_file_manager import ToolFileManager
 
@@ -82,7 +82,7 @@ class ToolFileMessageTransformer:
                     meta=message.meta.copy() if message.meta is not None else {},
                 ))
             elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
-                file_var: FileVar = message.meta.get('file_var')
+                file_var = message.meta.get('file_var')
                 if file_var:
                     if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
                         url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
@@ -4,13 +4,14 @@ from typing import Any, Optional
 from pydantic import BaseModel
 
 from core.model_runtime.entities.llm_entities import LLMUsage
-from models.workflow import WorkflowNodeExecutionStatus
+from models import WorkflowNodeExecutionStatus
 
 
 class NodeType(Enum):
     """
     Node Types.
     """
+
     START = 'start'
     END = 'end'
     ANSWER = 'answer'
@@ -44,33 +45,11 @@ class NodeType(Enum):
         raise ValueError(f'invalid node type value {value}')
 
 
-class SystemVariable(Enum):
-    """
-    System Variables.
-    """
-    QUERY = 'query'
-    FILES = 'files'
-    CONVERSATION_ID = 'conversation_id'
-    USER_ID = 'user_id'
-
-    @classmethod
-    def value_of(cls, value: str) -> 'SystemVariable':
-        """
-        Get value of given system variable.
-
-        :param value: system variable value
-        :return: system variable
-        """
-        for system_variable in cls:
-            if system_variable.value == value:
-                return system_variable
-        raise ValueError(f'invalid system variable value {value}')
-
-
 class NodeRunMetadataKey(Enum):
     """
     Node Run Metadata Key.
     """
+
     TOTAL_TOKENS = 'total_tokens'
     TOTAL_PRICE = 'total_price'
     CURRENCY = 'currency'
@@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
     """
     Node Run Result.
     """
+
     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
 
     inputs: Optional[dict[str, Any]] = None  # node inputs
|
|||||||
|
|
||||||
from core.app.segments import Segment, Variable, factory
|
from core.app.segments import Segment, Variable, factory
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.enums import SystemVariable
|
||||||
|
|
||||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||||
|
|
||||||
|
api/core/workflow/enums.py (new file, 25 lines added)
@@ -0,0 +1,25 @@
+from enum import Enum
+
+
+class SystemVariable(str, Enum):
+    """
+    System Variables.
+    """
+    QUERY = 'query'
+    FILES = 'files'
+    CONVERSATION_ID = 'conversation_id'
+    USER_ID = 'user_id'
+    DIALOGUE_COUNT = 'dialogue_count'
+
+    @classmethod
+    def value_of(cls, value: str):
+        """
+        Get value of given system variable.
+
+        :param value: system variable value
+        :return: system variable
+        """
+        for system_variable in cls:
+            if system_variable.value == value:
+                return system_variable
+        raise ValueError(f'invalid system variable value {value}')
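Two details of the relocated enum are worth noting: it now also subclasses str, so members compare equal to their raw values, and it gains a DIALOGUE_COUNT member. A trimmed copy for illustration (a local re-declaration, not an import of the new module):

from enum import Enum

class SystemVariable(str, Enum):
    QUERY = 'query'
    FILES = 'files'
    DIALOGUE_COUNT = 'dialogue_count'

assert SystemVariable.QUERY == 'query'           # str-backed members compare like plain strings
assert SystemVariable.FILES.value == 'files'
assert SystemVariable('query') is SystemVariable.QUERY  # same lookup value_of() performs, with a terser error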
@@ -140,9 +140,6 @@ class HttpRequestNode(BaseNode):
         """
         files = []
         mimetype, file_binary = response.extract_file()
-        # if not image, return directly
-        if 'image' not in mimetype:
-            return files
 
         if mimetype:
             # extract filename from url
@@ -1,7 +1,7 @@
 import json
 from collections.abc import Generator, Mapping, Sequence
 from copy import deepcopy
-from typing import Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 
 from pydantic import BaseModel
 
@@ -9,7 +9,6 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -24,9 +23,10 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import InNodeEvent
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.entities import (
@@ -41,6 +41,10 @@ from models.model import Conversation
 from models.provider import Provider, ProviderType
 from models.workflow import WorkflowNodeExecutionStatus
 
+if TYPE_CHECKING:
+    from core.file.file_obj import FileVar
+
+
 class ModelInvokeCompleted(BaseModel):
     """
@@ -81,7 +85,7 @@ class LLMNode(BaseNode):
         node_inputs = {}
 
         # fetch files
-        files: list[FileVar] = self._fetch_files(node_data, variable_pool)
+        files = self._fetch_files(node_data, variable_pool)
 
         if files:
             node_inputs['#files#'] = [file.to_dict() for file in files]
@@ -368,7 +372,7 @@ class LLMNode(BaseNode):
 
         return inputs  # type: ignore
 
-    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
+    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
         """
         Fetch files
         :param node_data: node data
@@ -563,7 +567,7 @@ class LLMNode(BaseNode):
                                query: Optional[str],
                                query_prompt_template: Optional[str],
                                inputs: dict[str, str],
-                               files: list[FileVar],
+                               files: list["FileVar"],
                                context: Optional[str],
                                memory: Optional[TokenBufferMemory],
                                model_config: ModelConfigWithCredentialsEntity) \
@@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
 from os import path
 from typing import Any, cast
 
-from core.app.segments import parser
+from core.app.segments import ArrayAnyVariable, parser
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
 from core.tools.tool_engine import ToolEngine
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariable
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
-from models.workflow import WorkflowNodeExecutionStatus
+from models import WorkflowNodeExecutionStatus
 
 
 class ToolNode(BaseNode):
@@ -140,9 +141,9 @@ class ToolNode(BaseNode):
         return result
 
     def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
-        # FIXME: ensure this is a ArrayVariable contains FileVariable.
         variable = variable_pool.get(['sys', SystemVariable.FILES.value])
-        return [file_var.value for file_var in variable.value] if variable else []
+        assert isinstance(variable, ArrayAnyVariable)
+        return list(variable.value) if variable else []
 
     def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
         -> tuple[str, list[FileVar], list[dict]]:
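The new _fetch_files body asserts that the pool entry is an ArrayAnyVariable before reading .value, rather than assuming each element wraps a file. A stand-in sketch of that narrowing pattern; ArrayAnyVariable here is a local stub for illustration, not the real segment class:

from typing import Any

class ArrayAnyVariable:  # stub standing in for core.app.segments.ArrayAnyVariable
    def __init__(self, value: list[Any]):
        self.value = value

def fetch_files(variable: object) -> list[Any]:
    if variable is None:
        return []
    # Fail loudly if the lookup returned something unexpected, then expose the raw list.
    assert isinstance(variable, ArrayAnyVariable)
    return list(variable.value)

print(fetch_files(ArrayAnyVariable(["a.png", "b.pdf"])))  # ['a.png', 'b.pdf']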
@@ -1,13 +1,13 @@
 from blinker import signal
 
 # sender: app
-app_was_created = signal('app-was-created')
+app_was_created = signal("app-was-created")
 
 # sender: app, kwargs: app_model_config
-app_model_config_was_updated = signal('app-model-config-was-updated')
+app_model_config_was_updated = signal("app-model-config-was-updated")
 
 # sender: app, kwargs: published_workflow
-app_published_workflow_was_updated = signal('app-published-workflow-was-updated')
+app_published_workflow_was_updated = signal("app-published-workflow-was-updated")
 
 # sender: app, kwargs: synced_draft_workflow
-app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced')
+app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")
@@ -1,4 +1,4 @@
 from blinker import signal
 
 # sender: dataset
-dataset_was_deleted = signal('dataset-was-deleted')
+dataset_was_deleted = signal("dataset-was-deleted")
@@ -1,4 +1,4 @@
 from blinker import signal
 
 # sender: document
-document_was_deleted = signal('document-was-deleted')
+document_was_deleted = signal("document-was-deleted")
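The hunks above only switch the signal names to double quotes. For readers new to blinker, the sender/receiver flow these modules rely on looks roughly like this; the ids are made up for the example:

from blinker import signal

document_was_deleted = signal("document-was-deleted")

@document_was_deleted.connect
def handle(sender, **kwargs):
    # sender carries the primary object (here a document id); extra context
    # such as dataset_id travels as keyword arguments.
    print(sender, kwargs.get("dataset_id"), kwargs.get("doc_form"))

document_was_deleted.send("doc-123", dataset_id="ds-456", doc_form="text_model", file_id=None)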
@@ -5,5 +5,11 @@ from tasks.clean_dataset_task import clean_dataset_task
 @dataset_was_deleted.connect
 def handle(sender, **kwargs):
     dataset = sender
-    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
-                             dataset.index_struct, dataset.collection_binding_id, dataset.doc_form)
+    clean_dataset_task.delay(
+        dataset.id,
+        dataset.tenant_id,
+        dataset.indexing_technique,
+        dataset.index_struct,
+        dataset.collection_binding_id,
+        dataset.doc_form,
+    )
@@ -5,7 +5,7 @@ from tasks.clean_document_task import clean_document_task
 @document_was_deleted.connect
 def handle(sender, **kwargs):
     document_id = sender
-    dataset_id = kwargs.get('dataset_id')
-    doc_form = kwargs.get('doc_form')
-    file_id = kwargs.get('file_id')
+    dataset_id = kwargs.get("dataset_id")
+    doc_form = kwargs.get("doc_form")
+    file_id = kwargs.get("file_id")
     clean_document_task.delay(document_id, dataset_id, doc_form, file_id)
@@ -14,21 +14,25 @@ from models.dataset import Document
 @document_index_created.connect
 def handle(sender, **kwargs):
     dataset_id = sender
-    document_ids = kwargs.get('document_ids', None)
+    document_ids = kwargs.get("document_ids", None)
     documents = []
     start_at = time.perf_counter()
     for document_id in document_ids:
-        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
+        logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
 
-        document = db.session.query(Document).filter(
-            Document.id == document_id,
-            Document.dataset_id == dataset_id
-        ).first()
+        document = (
+            db.session.query(Document)
+            .filter(
+                Document.id == document_id,
+                Document.dataset_id == dataset_id,
+            )
+            .first()
+        )
 
         if not document:
-            raise NotFound('Document not found')
+            raise NotFound("Document not found")
 
-        document.indexing_status = 'parsing'
+        document.indexing_status = "parsing"
         document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
         documents.append(document)
         db.session.add(document)
@@ -38,8 +42,8 @@ def handle(sender, **kwargs):
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()
-        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+        logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
     except DocumentIsPausedException as ex:
-        logging.info(click.style(str(ex), fg='yellow'))
+        logging.info(click.style(str(ex), fg="yellow"))
     except Exception:
         pass
@@ -10,7 +10,7 @@ def handle(sender, **kwargs):
     installed_app = InstalledApp(
         tenant_id=app.tenant_id,
         app_id=app.id,
-        app_owner_tenant_id=app.tenant_id
+        app_owner_tenant_id=app.tenant_id,
     )
     db.session.add(installed_app)
     db.session.commit()
@@ -7,15 +7,15 @@ from models.model import Site
 def handle(sender, **kwargs):
     """Create site record when an app is created."""
     app = sender
-    account = kwargs.get('account')
+    account = kwargs.get("account")
     site = Site(
         app_id=app.id,
         title=app.name,
         icon=app.icon,
         icon_background=app.icon_background,
         default_language=account.interface_language,
-        customize_token_strategy='not_allow',
-        code=Site.generate_code(16)
+        customize_token_strategy="not_allow",
+        code=Site.generate_code(16),
     )
 
     db.session.add(site)
|
|||||||
@message_was_created.connect
|
@message_was_created.connect
|
||||||
def handle(sender, **kwargs):
|
def handle(sender, **kwargs):
|
||||||
message = sender
|
message = sender
|
||||||
application_generate_entity = kwargs.get('application_generate_entity')
|
application_generate_entity = kwargs.get("application_generate_entity")
|
||||||
|
|
||||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||||
return
|
return
|
||||||
@ -39,7 +39,7 @@ def handle(sender, **kwargs):
|
|||||||
elif quota_unit == QuotaUnit.CREDITS:
|
elif quota_unit == QuotaUnit.CREDITS:
|
||||||
used_quota = 1
|
used_quota = 1
|
||||||
|
|
||||||
if 'gpt-4' in model_config.model:
|
if "gpt-4" in model_config.model:
|
||||||
used_quota = 20
|
used_quota = 20
|
||||||
else:
|
else:
|
||||||
used_quota = 1
|
used_quota = 1
|
||||||
@ -50,6 +50,6 @@ def handle(sender, **kwargs):
|
|||||||
Provider.provider_name == model_config.provider,
|
Provider.provider_name == model_config.provider,
|
||||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||||
Provider.quota_limit > Provider.quota_used
|
Provider.quota_limit > Provider.quota_used,
|
||||||
).update({'quota_used': Provider.quota_used + used_quota})
|
).update({"quota_used": Provider.quota_used + used_quota})
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@@ -8,8 +8,8 @@ from events.app_event import app_draft_workflow_was_synced
 @app_draft_workflow_was_synced.connect
 def handle(sender, **kwargs):
     app = sender
-    for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []):
-        if node_data.get('data', {}).get('type') == NodeType.TOOL.value:
+    for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []):
+        if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
             try:
                 tool_entity = ToolEntity(**node_data["data"])
                 tool_runtime = ToolManager.get_tool_runtime(
@@ -23,7 +23,7 @@ def handle(sender, **kwargs):
                     tool_runtime=tool_runtime,
                     provider_name=tool_entity.provider_name,
                     provider_type=tool_entity.provider_type,
-                    identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
+                    identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}',
                 )
                 manager.delete_tool_parameters_cache()
             except:
|
|||||||
from blinker import signal
|
from blinker import signal
|
||||||
|
|
||||||
# sender: document
|
# sender: document
|
||||||
document_index_created = signal('document-index-created')
|
document_index_created = signal("document-index-created")
|
||||||
|
@@ -7,13 +7,11 @@ from models.model import AppModelConfig
 @app_model_config_was_updated.connect
 def handle(sender, **kwargs):
     app = sender
-    app_model_config = kwargs.get('app_model_config')
+    app_model_config = kwargs.get("app_model_config")
 
     dataset_ids = get_dataset_ids_from_model_config(app_model_config)
 
-    app_dataset_joins = db.session.query(AppDatasetJoin).filter(
-        AppDatasetJoin.app_id == app.id
-    ).all()
+    app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
 
     removed_dataset_ids = []
     if not app_dataset_joins:
@@ -29,16 +27,12 @@ def handle(sender, **kwargs):
     if removed_dataset_ids:
         for dataset_id in removed_dataset_ids:
             db.session.query(AppDatasetJoin).filter(
-                AppDatasetJoin.app_id == app.id,
-                AppDatasetJoin.dataset_id == dataset_id
+                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
             ).delete()
 
     if added_dataset_ids:
         for dataset_id in added_dataset_ids:
-            app_dataset_join = AppDatasetJoin(
-                app_id=app.id,
-                dataset_id=dataset_id
-            )
+            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
             db.session.add(app_dataset_join)
 
     db.session.commit()
@@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
 
     agent_mode = app_model_config.agent_mode_dict
 
-    tools = agent_mode.get('tools', []) or []
+    tools = agent_mode.get("tools", []) or []
     for tool in tools:
         if len(list(tool.keys())) != 1:
             continue
@@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
 
     # get dataset from dataset_configs
     dataset_configs = app_model_config.dataset_configs_dict
-    datasets = dataset_configs.get('datasets', {}) or {}
-    for dataset in datasets.get('datasets', []) or []:
+    datasets = dataset_configs.get("datasets", {}) or {}
+    for dataset in datasets.get("datasets", []) or []:
         keys = list(dataset.keys())
-        if len(keys) == 1 and keys[0] == 'dataset':
-            if dataset['dataset'].get('id'):
-                dataset_ids.add(dataset['dataset'].get('id'))
+        if len(keys) == 1 and keys[0] == "dataset":
+            if dataset["dataset"].get("id"):
+                dataset_ids.add(dataset["dataset"].get("id"))
 
     return dataset_ids
@@ -11,13 +11,11 @@ from models.workflow import Workflow
 @app_published_workflow_was_updated.connect
 def handle(sender, **kwargs):
     app = sender
-    published_workflow = kwargs.get('published_workflow')
+    published_workflow = kwargs.get("published_workflow")
     published_workflow = cast(Workflow, published_workflow)
 
     dataset_ids = get_dataset_ids_from_workflow(published_workflow)
-    app_dataset_joins = db.session.query(AppDatasetJoin).filter(
-        AppDatasetJoin.app_id == app.id
-    ).all()
+    app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
 
     removed_dataset_ids = []
     if not app_dataset_joins:
@@ -33,16 +31,12 @@ def handle(sender, **kwargs):
     if removed_dataset_ids:
         for dataset_id in removed_dataset_ids:
             db.session.query(AppDatasetJoin).filter(
-                AppDatasetJoin.app_id == app.id,
-                AppDatasetJoin.dataset_id == dataset_id
+                AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
             ).delete()
 
     if added_dataset_ids:
         for dataset_id in added_dataset_ids:
-            app_dataset_join = AppDatasetJoin(
-                app_id=app.id,
-                dataset_id=dataset_id
-            )
+            app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
             db.session.add(app_dataset_join)
 
     db.session.commit()
@@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
     if not graph:
         return dataset_ids
 
-    nodes = graph.get('nodes', [])
+    nodes = graph.get("nodes", [])
 
     # fetch all knowledge retrieval nodes
-    knowledge_retrieval_nodes = [node for node in nodes
-                                 if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value]
+    knowledge_retrieval_nodes = [
+        node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
+    ]
 
     if not knowledge_retrieval_nodes:
         return dataset_ids
 
     for node in knowledge_retrieval_nodes:
         try:
-            node_data = KnowledgeRetrievalNodeData(**node.get('data', {}))
+            node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
             dataset_ids.update(node_data.dataset_ids)
         except Exception as e:
             continue
@@ -9,13 +9,13 @@ from models.provider import Provider
 @message_was_created.connect
 def handle(sender, **kwargs):
     message = sender
-    application_generate_entity = kwargs.get('application_generate_entity')
+    application_generate_entity = kwargs.get("application_generate_entity")
 
     if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
         return
 
     db.session.query(Provider).filter(
         Provider.tenant_id == application_generate_entity.app_config.tenant_id,
-        Provider.provider_name == application_generate_entity.model_conf.provider
-    ).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)})
+        Provider.provider_name == application_generate_entity.model_conf.provider,
+    ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)})
     db.session.commit()
@@ -1,4 +1,4 @@
 from blinker import signal
 
 # sender: message, kwargs: conversation
-message_was_created = signal('message-was-created')
+message_was_created = signal("message-was-created")
@@ -1,7 +1,7 @@
 from blinker import signal
 
 # sender: tenant
-tenant_was_created = signal('tenant-was-created')
+tenant_was_created = signal("tenant-was-created")
 
 # sender: tenant
-tenant_was_updated = signal('tenant-was-updated')
+tenant_was_updated = signal("tenant-was-updated")
@@ -45,18 +45,15 @@ def init_app(app: Flask) -> Celery:
     ]
     day = app.config["CELERY_BEAT_SCHEDULER_TIME"]
     beat_schedule = {
-        'clean_embedding_cache_task': {
-            'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task',
-            'schedule': timedelta(days=day),
-        },
-        'clean_unused_datasets_task': {
-            'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
-            'schedule': timedelta(days=day),
-        }
+        "clean_embedding_cache_task": {
+            "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
+            "schedule": timedelta(days=day),
+        },
+        "clean_unused_datasets_task": {
+            "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
+            "schedule": timedelta(days=day),
+        },
     }
-    celery_app.conf.update(
-        beat_schedule=beat_schedule,
-        imports=imports
-    )
+    celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
 
     return celery_app
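Functionally the hunk is unchanged: two periodic tasks registered through beat_schedule, now passed to conf.update in a single call. A minimal standalone equivalent; the broker URL and the one-day interval are placeholders for the real configuration:

from datetime import timedelta

from celery import Celery

celery_app = Celery(__name__, broker="redis://localhost:6379/1")  # placeholder broker

celery_app.conf.update(
    beat_schedule={
        "clean_embedding_cache_task": {
            "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
            "schedule": timedelta(days=1),  # CELERY_BEAT_SCHEDULER_TIME in the real app
        },
        "clean_unused_datasets_task": {
            "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
            "schedule": timedelta(days=1),
        },
    },
    imports=["schedule.clean_embedding_cache_task", "schedule.clean_unused_datasets_task"],
)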
@@ -2,15 +2,14 @@ from flask import Flask
 
 
 def init_app(app: Flask):
-    if app.config.get('API_COMPRESSION_ENABLED'):
+    if app.config.get("API_COMPRESSION_ENABLED"):
         from flask_compress import Compress
 
-        app.config['COMPRESS_MIMETYPES'] = [
-            'application/json',
-            'image/svg+xml',
-            'text/html',
+        app.config["COMPRESS_MIMETYPES"] = [
+            "application/json",
+            "image/svg+xml",
+            "text/html",
         ]
 
         compress = Compress()
         compress.init_app(app)
@@ -2,11 +2,11 @@ from flask_sqlalchemy import SQLAlchemy
 from sqlalchemy import MetaData
 
 POSTGRES_INDEXES_NAMING_CONVENTION = {
-    'ix': '%(column_0_label)s_idx',
-    'uq': '%(table_name)s_%(column_0_name)s_key',
-    'ck': '%(table_name)s_%(constraint_name)s_check',
-    'fk': '%(table_name)s_%(column_0_name)s_fkey',
-    'pk': '%(table_name)s_pkey',
+    "ix": "%(column_0_label)s_idx",
+    "uq": "%(table_name)s_%(column_0_name)s_key",
+    "ck": "%(table_name)s_%(constraint_name)s_check",
+    "fk": "%(table_name)s_%(column_0_name)s_fkey",
+    "pk": "%(table_name)s_pkey",
 }
 
 metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
@@ -14,67 +14,69 @@ class Mail:
         return self._client is not None
 
     def init_app(self, app: Flask):
-        if app.config.get('MAIL_TYPE'):
-            if app.config.get('MAIL_DEFAULT_SEND_FROM'):
-                self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
+        if app.config.get("MAIL_TYPE"):
+            if app.config.get("MAIL_DEFAULT_SEND_FROM"):
+                self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM")
 
-            if app.config.get('MAIL_TYPE') == 'resend':
-                api_key = app.config.get('RESEND_API_KEY')
+            if app.config.get("MAIL_TYPE") == "resend":
+                api_key = app.config.get("RESEND_API_KEY")
                 if not api_key:
-                    raise ValueError('RESEND_API_KEY is not set')
+                    raise ValueError("RESEND_API_KEY is not set")
 
-                api_url = app.config.get('RESEND_API_URL')
+                api_url = app.config.get("RESEND_API_URL")
                 if api_url:
                     resend.api_url = api_url
 
                 resend.api_key = api_key
                 self._client = resend.Emails
-            elif app.config.get('MAIL_TYPE') == 'smtp':
+            elif app.config.get("MAIL_TYPE") == "smtp":
                 from libs.smtp import SMTPClient
-                if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
-                    raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type')
-                if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'):
-                    raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS')
+
+                if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"):
+                    raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
+                if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"):
+                    raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
                 self._client = SMTPClient(
-                    server=app.config.get('SMTP_SERVER'),
-                    port=app.config.get('SMTP_PORT'),
-                    username=app.config.get('SMTP_USERNAME'),
-                    password=app.config.get('SMTP_PASSWORD'),
-                    _from=app.config.get('MAIL_DEFAULT_SEND_FROM'),
-                    use_tls=app.config.get('SMTP_USE_TLS'),
-                    opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS')
+                    server=app.config.get("SMTP_SERVER"),
+                    port=app.config.get("SMTP_PORT"),
+                    username=app.config.get("SMTP_USERNAME"),
+                    password=app.config.get("SMTP_PASSWORD"),
+                    _from=app.config.get("MAIL_DEFAULT_SEND_FROM"),
+                    use_tls=app.config.get("SMTP_USE_TLS"),
+                    opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"),
                 )
             else:
-                raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
+                raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE")))
         else:
-            logging.warning('MAIL_TYPE is not set')
+            logging.warning("MAIL_TYPE is not set")
 
     def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
         if not self._client:
-            raise ValueError('Mail client is not initialized')
+            raise ValueError("Mail client is not initialized")
 
         if not from_ and self._default_send_from:
             from_ = self._default_send_from
 
         if not from_:
-            raise ValueError('mail from is not set')
+            raise ValueError("mail from is not set")
 
         if not to:
-            raise ValueError('mail to is not set')
+            raise ValueError("mail to is not set")
 
         if not subject:
-            raise ValueError('mail subject is not set')
+            raise ValueError("mail subject is not set")
 
         if not html:
-            raise ValueError('mail html is not set')
+            raise ValueError("mail html is not set")
 
-        self._client.send({
-            "from": from_,
-            "to": to,
-            "subject": subject,
-            "html": html
-        })
+        self._client.send(
+            {
+                "from": from_,
+                "to": to,
+                "subject": subject,
+                "html": html,
+            }
+        )
 
 
 def init_app(app: Flask):
@@ -6,18 +6,21 @@ redis_client = redis.Redis()
 
 def init_app(app):
     connection_class = Connection
-    if app.config.get('REDIS_USE_SSL'):
+    if app.config.get("REDIS_USE_SSL"):
         connection_class = SSLConnection
 
-    redis_client.connection_pool = redis.ConnectionPool(**{
-        'host': app.config.get('REDIS_HOST'),
-        'port': app.config.get('REDIS_PORT'),
-        'username': app.config.get('REDIS_USERNAME'),
-        'password': app.config.get('REDIS_PASSWORD'),
-        'db': app.config.get('REDIS_DB'),
-        'encoding': 'utf-8',
-        'encoding_errors': 'strict',
-        'decode_responses': False
-    }, connection_class=connection_class)
+    redis_client.connection_pool = redis.ConnectionPool(
+        **{
+            "host": app.config.get("REDIS_HOST"),
+            "port": app.config.get("REDIS_PORT"),
+            "username": app.config.get("REDIS_USERNAME"),
+            "password": app.config.get("REDIS_PASSWORD"),
+            "db": app.config.get("REDIS_DB"),
+            "encoding": "utf-8",
+            "encoding_errors": "strict",
+            "decode_responses": False,
+        },
+        connection_class=connection_class,
+    )
 
-    app.extensions['redis'] = redis_client
+    app.extensions["redis"] = redis_client
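The redis hunk is again a pure re-indentation of the same ConnectionPool arguments. As a reference for what those keyword arguments do, a standalone sketch with placeholder connection details:

import redis

use_ssl = False  # stands in for REDIS_USE_SSL
connection_class = redis.SSLConnection if use_ssl else redis.Connection

redis_client = redis.Redis()
redis_client.connection_pool = redis.ConnectionPool(
    host="localhost",        # placeholder for REDIS_HOST
    port=6379,               # placeholder for REDIS_PORT
    username=None,
    password=None,
    db=0,
    encoding="utf-8",
    encoding_errors="strict",
    decode_responses=False,  # callers receive bytes, not str
    connection_class=connection_class,
)
# redis_client.set("ping", "pong")  # requires a reachable Redis server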
@@ -5,16 +5,13 @@ from werkzeug.exceptions import HTTPException
 
 
 def init_app(app):
-    if app.config.get('SENTRY_DSN'):
+    if app.config.get("SENTRY_DSN"):
         sentry_sdk.init(
-            dsn=app.config.get('SENTRY_DSN'),
-            integrations=[
-                FlaskIntegration(),
-                CeleryIntegration()
-            ],
+            dsn=app.config.get("SENTRY_DSN"),
+            integrations=[FlaskIntegration(), CeleryIntegration()],
             ignore_errors=[HTTPException, ValueError],
-            traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0),
-            profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0),
-            environment=app.config.get('DEPLOY_ENV'),
-            release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}"
+            traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0),
+            profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0),
+            environment=app.config.get("DEPLOY_ENV"),
+            release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}",
         )
@@ -17,31 +17,19 @@ class Storage:
         self.storage_runner = None
 
     def init_app(self, app: Flask):
-        storage_type = app.config.get('STORAGE_TYPE')
-        if storage_type == 's3':
-            self.storage_runner = S3Storage(
-                app=app
-            )
-        elif storage_type == 'azure-blob':
-            self.storage_runner = AzureStorage(
-                app=app
-            )
-        elif storage_type == 'aliyun-oss':
-            self.storage_runner = AliyunStorage(
-                app=app
-            )
-        elif storage_type == 'google-storage':
-            self.storage_runner = GoogleStorage(
-                app=app
-            )
-        elif storage_type == 'tencent-cos':
-            self.storage_runner = TencentStorage(
-                app=app
-            )
-        elif storage_type == 'oci-storage':
-            self.storage_runner = OCIStorage(
-                app=app
-            )
+        storage_type = app.config.get("STORAGE_TYPE")
+        if storage_type == "s3":
+            self.storage_runner = S3Storage(app=app)
+        elif storage_type == "azure-blob":
+            self.storage_runner = AzureStorage(app=app)
+        elif storage_type == "aliyun-oss":
+            self.storage_runner = AliyunStorage(app=app)
+        elif storage_type == "google-storage":
+            self.storage_runner = GoogleStorage(app=app)
+        elif storage_type == "tencent-cos":
+            self.storage_runner = TencentStorage(app=app)
+        elif storage_type == "oci-storage":
+            self.storage_runner = OCIStorage(app=app)
         else:
             self.storage_runner = LocalStorage(app=app)
Some files were not shown because too many files have changed in this diff.