Merge branch 'refs/heads/main' into feat/workflow-parallel-support

# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/base_app_runner.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/node_entities.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/workflow_engine_manager.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
takatost 2024-08-16 01:19:29 +08:00
commit db9b0ee985
196 changed files with 4070 additions and 2133 deletions

View File

@ -76,7 +76,7 @@ jobs:
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
@ -90,5 +90,6 @@ jobs:
pgvecto-rs
pgvector
chroma
elasticsearch
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh

View File

@ -6,5 +6,6 @@ yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs."
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch"

View File

@ -45,6 +45,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
- name: Ruff formatter check
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff format --check ./api
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."

View File

@ -130,6 +130,12 @@ TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
# ElasticSearch configuration
ELASTICSEARCH_HOST=127.0.0.1
ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic
# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431
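The four variables above are everything a client needs to reach a local node with basic auth. A minimal sketch of consuming them, assuming the elasticsearch-py 8.x client (the import and the basic_auth parameter are assumptions about the library, not part of this commit):

import os

from elasticsearch import Elasticsearch  # assumed: elasticsearch-py 8.x

# Same variable names as introduced in .env.example above.
host = os.environ.get("ELASTICSEARCH_HOST", "127.0.0.1")
port = os.environ.get("ELASTICSEARCH_PORT", "9200")
user = os.environ.get("ELASTICSEARCH_USERNAME", "elastic")
password = os.environ.get("ELASTICSEARCH_PASSWORD", "elastic")

client = Elasticsearch(f"http://{host}:{port}", basic_auth=(user, password))
print(client.info())  # smoke-test the connection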

View File

@ -1,6 +1,6 @@
import os
if os.environ.get("DEBUG", "false").lower() != 'true':
if os.environ.get("DEBUG", "false").lower() != "true":
from gevent import monkey
monkey.patch_all()
@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning)
if os.name == "nt":
os.system('tzutil /s "UTC"')
else:
os.environ['TZ'] = 'UTC'
os.environ["TZ"] = "UTC"
time.tzset()
@ -70,13 +70,14 @@ class DifyApp(Flask):
# -------------
config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
# ----------------------------
# Application Factory Function
# ----------------------------
def create_flask_app_with_configs() -> Flask:
"""
create a raw flask app
@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask:
elif isinstance(value, int | float | bool):
os.environ[key] = str(value)
elif value is None:
os.environ[key] = ''
os.environ[key] = ""
return dify_app
@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask:
def create_app() -> Flask:
app = create_flask_app_with_configs()
app.secret_key = app.config['SECRET_KEY']
app.secret_key = app.config["SECRET_KEY"]
log_handlers = None
log_file = app.config.get('LOG_FILE')
log_file = app.config.get("LOG_FILE")
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
@ -111,23 +112,24 @@ def create_app() -> Flask:
RotatingFileHandler(
filename=log_file,
maxBytes=1024 * 1024 * 1024,
backupCount=5
backupCount=5,
),
logging.StreamHandler(sys.stdout)
logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
level=app.config.get('LOG_LEVEL'),
format=app.config.get('LOG_FORMAT'),
datefmt=app.config.get('LOG_DATEFORMAT'),
level=app.config.get("LOG_LEVEL"),
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
handlers=log_handlers,
force=True
force=True,
)
log_tz = app.config.get('LOG_TZ')
log_tz = app.config.get("LOG_TZ")
if log_tz:
from datetime import datetime
import pytz
timezone = pytz.timezone(log_tz)
def time_converter(seconds):
@ -162,24 +164,24 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in ['console', 'inner_api']:
if request.blueprint not in ["console", "inner_api"]:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
auth_header = request.headers.get("Authorization", "")
if not auth_header:
auth_token = request.args.get('_token')
auth_token = request.args.get("_token")
if not auth_token:
raise Unauthorized('Invalid Authorization token.')
raise Unauthorized("Invalid Authorization token.")
else:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')
user_id = decoded.get("user_id")
account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
if account:
@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login):
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
return Response(json.dumps({
'code': 'unauthorized',
'message': "Unauthorized."
}), status=401, content_type="application/json")
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
content_type="application/json",
)
# register blueprint routers
@ -204,38 +207,36 @@ def register_blueprints(app):
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
CORS(service_api_bp,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
CORS(web_bp,
resources={
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)
CORS(
web_bp,
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(web_bp)
CORS(console_app_bp,
resources={
r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)
CORS(
console_app_bp,
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(console_app_bp)
CORS(files_bp,
allow_headers=['Content-Type'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp)
@ -245,29 +246,29 @@ def register_blueprints(app):
app = create_app()
celery = app.extensions["celery"]
if app.config.get('TESTING'):
if app.config.get("TESTING"):
print("App is running in TESTING mode")
@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
response.set_cookie("remember_token", "", expires=0)
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
return response
@app.route('/health')
@app.route("/health")
def health():
return Response(json.dumps({
'pid': os.getpid(),
'status': 'ok',
'version': app.config['CURRENT_VERSION']
}), status=200, content_type="application/json")
return Response(
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
status=200,
content_type="application/json",
)
@app.route('/threads')
@app.route("/threads")
def threads():
num_threads = threading.active_count()
threads = threading.enumerate()
@ -278,32 +279,34 @@ def threads():
thread_id = thread.ident
is_alive = thread.is_alive()
thread_list.append({
'name': thread_name,
'id': thread_id,
'is_alive': is_alive
})
thread_list.append(
{
"name": thread_name,
"id": thread_id,
"is_alive": is_alive,
}
)
return {
'pid': os.getpid(),
'thread_num': num_threads,
'threads': thread_list
"pid": os.getpid(),
"thread_num": num_threads,
"threads": thread_list,
}
@app.route('/db-pool-stat')
@app.route("/db-pool-stat")
def pool_stat():
engine = db.engine
return {
'pid': os.getpid(),
'pool_size': engine.pool.size(),
'checked_in_connections': engine.pool.checkedin(),
'checked_out_connections': engine.pool.checkedout(),
'overflow_connections': engine.pool.overflow(),
'connection_timeout': engine.pool.timeout(),
'recycle_time': db.engine.pool._recycle
"pid": os.getpid(),
"pool_size": engine.pool.size(),
"checked_in_connections": engine.pool.checkedin(),
"checked_out_connections": engine.pool.checkedout(),
"overflow_connections": engine.pool.overflow(),
"connection_timeout": engine.pool.timeout(),
"recycle_time": db.engine.pool._recycle,
}
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5001)
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
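The LOG_TZ branch in the hunk above is cut off right after the time_converter definition. A plausible completion, assuming the standard logging.Formatter.converter hook (this body is a sketch, not lifted from the diff):

import logging
from datetime import datetime

import pytz

timezone = pytz.timezone("UTC")  # stand-in for app.config.get("LOG_TZ")

def time_converter(seconds):
    # Render log timestamps in the configured timezone instead of server-local time.
    return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

for handler in logging.root.handlers:
    if handler.formatter:
        handler.formatter.converter = time_converter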

View File

@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService
@click.command('reset-password', help='Reset the account password.')
@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset')
@click.option('--new-password', prompt=True, help='the new password.')
@click.option('--password-confirm', prompt=True, help='the new password confirm.')
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
@click.option("--new-password", prompt=True, help="the new password.")
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
def reset_password(email, new_password, password_confirm):
"""
Reset password of owner account
Only available in SELF_HOSTED mode
"""
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style('sorry. The two passwords do not match.', fg='red'))
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
return
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(
click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red'))
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
return
# generate password salt
@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
@click.command('reset-email', help='Reset the account email.')
@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset')
@click.option('--new-email', prompt=True, help='the new email.')
@click.option('--email-confirm', prompt=True, help='the new email confirm.')
@click.command("reset-email", help="Reset the account email.")
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
@click.option("--new-email", prompt=True, help="the new email.")
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
def reset_email(email, new_email, email_confirm):
"""
Replace account email
:return:
"""
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red'))
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
return
account = db.session.query(Account). \
filter(Account.email == email). \
one_or_none()
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red'))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
email_validate(new_email)
except:
click.echo(
click.style('sorry. {} is not a valid email. '.format(email), fg='red'))
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
return
account.email = new_email
db.session.commit()
click.echo(click.style('Congratulations!, email has been reset.', fg='green'))
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. '
'After the reset, all LLM credentials will become invalid, '
'requiring re-entry.'
'Only support SELF_HOSTED mode.')
@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?'
' this operation cannot be rolled back!', fg='red'))
@click.command(
"reset-encrypt-key-pair",
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
"After the reset, all LLM credentials will become invalid, "
"requiring re-entry."
"Only support SELF_HOSTED mode.",
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
)
)
def reset_encrypt_key_pair():
"""
Reset the encrypted key pair of workspace for encrypt LLM credentials.
After the reset, all LLM credentials will become invalid, requiring re-entry.
Only support SELF_HOSTED mode.
"""
if dify_config.EDITION != 'SELF_HOSTED':
click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red'))
if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
return
tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red'))
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete()
db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
click.echo(click.style('Congratulations! '
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
click.echo(
click.style(
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green",
)
)
@click.command('vdb-migrate', help='migrate vector db.')
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str):
if scope in ['knowledge', 'all']:
if scope in ["knowledge", "all"]:
migrate_knowledge_vector_database()
if scope in ['annotation', 'all']:
if scope in ["annotation", "all"]:
migrate_annotation_vector_database()
@ -146,7 +150,7 @@ def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
"""
click.echo(click.style('Start migrate annotation data.', fg='green'))
click.echo(click.style("Start migrate annotation data.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@ -154,98 +158,103 @@ def migrate_annotation_vector_database():
while True:
try:
# get apps info
apps = db.session.query(App).filter(
App.status == 'normal'
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
apps = (
db.session.query(App)
.filter(App.status == "normal")
.order_by(App.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound:
break
page += 1
for app in apps:
total_count = total_count + 1
click.echo(f'Processing the {total_count} app {app.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
click.echo(
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo('Create app annotation index: {}'.format(app.id))
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app.id
).first()
click.echo("Create app annotation index: {}".format(app.id))
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo('App annotation setting is disabled: {}'.format(app.id))
click.echo("App annotation setting is disabled: {}".format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
).first()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
click.echo("App annotation collection binding is not exist: {}".format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
collection_binding_id=dataset_collection_binding.id,
)
documents = []
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={
"annotation_id": annotation.id,
"app_id": app.id,
"doc_id": annotation.id
}
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
)
documents.append(document)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index for app: {app.id}.',
fg='green'))
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index for app {app.id}.',
fg='red'))
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
raise e
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
click.style(
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
raise e
click.echo(f'Successfully migrated app annotation {app.id}.')
click.echo(f"Successfully migrated app annotation {app.id}.")
create_count += 1
except Exception as e:
click.echo(
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.style(
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
)
)
continue
click.echo(
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
fg='green'))
click.style(
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
fg="green",
)
)
def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
"""
click.echo(click.style('Start migrate vector db.', fg='green'))
click.echo(click.style("Start migrate vector db.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@ -253,87 +262,77 @@ def migrate_knowledge_vector_database():
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
datasets = (
db.session.query(Dataset)
.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
)
except NotFound:
break
page += 1
for dataset in datasets:
total_count = total_count + 1
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
click.echo(
f"Processing the {total_count} dataset {dataset.id}. "
+ f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
click.echo("Create dataset vdb index: {}".format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type:
if dataset.index_struct_dict["type"] == vector_type:
skipped_count = skipped_count + 1
continue
collection_name = ''
collection_name = ""
if vector_type == VectorType.WEAVIATE:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.WEAVIATE,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
raise ValueError("Dataset Collection Bindings is not exist!")
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.QDRANT,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.MILVUS:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.MILVUS,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.RELYT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'relyt',
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.TENCENT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.TENCENT,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.PGVECTOR,
"vector_store": {"class_prefix": collection_name}
}
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.OPENSEARCH:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.OPENSEARCH,
"vector_store": {"class_prefix": collection_name}
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ANALYTICDB:
@ -341,9 +340,14 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.ANALYTICDB,
"vector_store": {"class_prefix": collection_name}
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.ELASTICSEARCH:
dataset_id = dataset.id
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")
@ -353,29 +357,41 @@ def migrate_knowledge_vector_database():
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
fg='green'))
click.style(
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
)
)
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
fg='red'))
click.style(
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
)
)
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
for segment in segments:
document = Document(
@ -385,7 +401,7 @@ def migrate_knowledge_vector_database():
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
},
)
documents.append(document)
@ -393,37 +409,43 @@ def migrate_knowledge_vector_database():
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
fg='green'))
click.echo(
click.style(
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
)
except Exception as e:
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
raise e
db.session.add(dataset)
db.session.commit()
click.echo(f'Successfully migrated dataset {dataset.id}.')
click.echo(f"Successfully migrated dataset {dataset.id}.")
create_count += 1
except Exception as e:
db.session.rollback()
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
)
continue
click.echo(
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
fg='green'))
click.style(
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
)
)
@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.')
@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.")
def convert_to_agent_apps():
"""
Convert Agent Assistant to Agent App.
"""
click.echo(click.style('Start convert to agent apps.', fg='green'))
click.echo(click.style("Start convert to agent apps.", fg="green"))
proceeded_app_ids = []
@ -458,7 +480,7 @@ def convert_to_agent_apps():
break
for app in apps:
click.echo('Converting app: {}'.format(app.id))
click.echo("Converting app: {}".format(app.id))
try:
app.mode = AppMode.AGENT_CHAT.value
@ -470,137 +492,139 @@ def convert_to_agent_apps():
)
db.session.commit()
click.echo(click.style('Converted app: {}'.format(app.id), fg='green'))
click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
except Exception as e:
click.echo(
click.style('Convert app error: {} {}'.format(e.__class__.__name__,
str(e)), fg='red'))
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
def add_qdrant_doc_id_index(field: str):
click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
vector_type = dify_config.VECTOR_STORE
if vector_type != "qdrant":
click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
return
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings:
click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
return
import qdrant_client
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError('Qdrant url is required.')
raise ValueError("Qdrant url is required.")
qdrant_config = QdrantConfig(
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,
)
try:
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
# create payload index
client.create_payload_index(binding.collection_name, field,
field_schema=PayloadSchemaType.KEYWORD)
client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD)
create_count += 1
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
click.echo(
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
)
continue
# Some other error occurred, so re-raise the exception
else:
click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
click.echo(
click.style(
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
)
)
except Exception as e:
click.echo(click.style('Failed to create qdrant client.', fg='red'))
click.echo(click.style("Failed to create qdrant client.", fg="red"))
click.echo(
click.style(f'Congratulations! Create {create_count} collection indexes.',
fg='green'))
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
@click.command('create-tenant', help='Create account and tenant.')
@click.option('--email', prompt=True, help='The email address of the tenant account.')
@click.option('--language', prompt=True, help='Account language, default: en-US.')
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None):
"""
Create tenant account
"""
if not email:
click.echo(click.style('Sorry, email is required.', fg='red'))
click.echo(click.style("Sorry, email is required.", fg="red"))
return
# Create account
email = email.strip()
if '@' not in email:
click.echo(click.style('Sorry, invalid email address.', fg='red'))
if "@" not in email:
click.echo(click.style("Sorry, invalid email address.", fg="red"))
return
account_name = email.split('@')[0]
account_name = email.split("@")[0]
if language not in languages:
language = 'en-US'
language = "en-US"
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(
email=email,
name=account_name,
password=new_password,
language=language
)
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
TenantService.create_owner_tenant_if_not_exist(account)
click.echo(click.style('Congratulations! Account and tenant created.\n'
'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
click.echo(
click.style(
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)
@click.command('upgrade-db', help='upgrade the database')
@click.command("upgrade-db", help="upgrade the database")
def upgrade_db():
click.echo('Preparing database migration...')
lock = redis_client.lock(name='db_upgrade_lock', timeout=60)
click.echo("Preparing database migration...")
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
try:
click.echo(click.style('Start database migration.', fg='green'))
click.echo(click.style("Start database migration.", fg="green"))
# run db migration
import flask_migrate
flask_migrate.upgrade()
click.echo(click.style('Database migration successful!', fg='green'))
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
logging.exception(f'Database migration failed, error: {e}')
logging.exception(f"Database migration failed, error: {e}")
finally:
lock.release()
else:
click.echo('Database migration skipped')
click.echo("Database migration skipped")
@click.command('fix-app-site-missing', help='Fix app related site missing issue.')
@click.command("fix-app-site-missing", help="Fix app related site missing issue.")
def fix_app_site_missing():
"""
Fix app related site missing issue.
"""
click.echo(click.style('Start fix app related site missing issue.', fg='green'))
click.echo(click.style("Start fix app related site missing issue.", fg="green"))
failed_app_ids = []
while True:
@ -631,15 +655,14 @@ where sites.id is null limit 1000"""
app_was_created.send(app, account=account)
except Exception as e:
failed_app_ids.append(app_id)
click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red'))
logging.exception(f'Fix app related site missing issue failed, error: {e}')
click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
logging.exception(f"Fix app related site missing issue failed, error: {e}")
continue
if not processed_count:
break
click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green'))
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
def register_commands(app):
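Every vector-store branch in the migration above persists the same two-key JSON shape into dataset.index_struct, including the new elasticsearch branch. A small sketch of reading that value back (the helper name and sample value are hypothetical):

import json

def read_index_struct(index_struct: str) -> tuple[str, str]:
    # Shape written by the migration: {"type": ..., "vector_store": {"class_prefix": ...}}
    struct = json.loads(index_struct)
    return struct["type"], struct["vector_store"]["class_prefix"]

vector_type, collection_name = read_index_struct(
    '{"type": "elasticsearch", "vector_store": {"class_prefix": "Vector_index_demo_Node"}}'
)
print(vector_type, collection_name)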

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.6.16',
default='0.7.0',
)
COMMIT_SHA: str = Field(

View File

@ -1 +1 @@
HIDDEN_VALUE = '[__HIDDEN__]'
HIDDEN_VALUE = "[__HIDDEN__]"

View File

@ -1,22 +1,22 @@
language_timezone_mapping = {
'en-US': 'America/New_York',
'zh-Hans': 'Asia/Shanghai',
'zh-Hant': 'Asia/Taipei',
'pt-BR': 'America/Sao_Paulo',
'es-ES': 'Europe/Madrid',
'fr-FR': 'Europe/Paris',
'de-DE': 'Europe/Berlin',
'ja-JP': 'Asia/Tokyo',
'ko-KR': 'Asia/Seoul',
'ru-RU': 'Europe/Moscow',
'it-IT': 'Europe/Rome',
'uk-UA': 'Europe/Kyiv',
'vi-VN': 'Asia/Ho_Chi_Minh',
'ro-RO': 'Europe/Bucharest',
'pl-PL': 'Europe/Warsaw',
'hi-IN': 'Asia/Kolkata',
'tr-TR': 'Europe/Istanbul',
'fa-IR': 'Asia/Tehran',
"en-US": "America/New_York",
"zh-Hans": "Asia/Shanghai",
"zh-Hant": "Asia/Taipei",
"pt-BR": "America/Sao_Paulo",
"es-ES": "Europe/Madrid",
"fr-FR": "Europe/Paris",
"de-DE": "Europe/Berlin",
"ja-JP": "Asia/Tokyo",
"ko-KR": "Asia/Seoul",
"ru-RU": "Europe/Moscow",
"it-IT": "Europe/Rome",
"uk-UA": "Europe/Kyiv",
"vi-VN": "Asia/Ho_Chi_Minh",
"ro-RO": "Europe/Bucharest",
"pl-PL": "Europe/Warsaw",
"hi-IN": "Asia/Kolkata",
"tr-TR": "Europe/Istanbul",
"fa-IR": "Asia/Tehran",
}
languages = list(language_timezone_mapping.keys())
@ -26,6 +26,5 @@ def supported_language(lang):
if lang in languages:
return lang
error = ('{lang} is not a valid language.'
.format(lang=lang))
error = "{lang} is not a valid language.".format(lang=lang)
raise ValueError(error)

View File

@ -5,82 +5,79 @@ from models.model import AppMode
default_app_templates = {
# workflow default mode
AppMode.WORKFLOW: {
'app': {
'mode': AppMode.WORKFLOW.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.WORKFLOW.value,
"enable_site": True,
"enable_api": True,
}
},
# completion default mode
AppMode.COMPLETION: {
'app': {
'mode': AppMode.COMPLETION.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.COMPLETION.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
"completion_params": {},
},
'user_input_form': json.dumps([
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
]),
'pre_prompt': '{{query}}'
"user_input_form": json.dumps(
[
{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": "",
},
},
]
),
"pre_prompt": "{{query}}",
},
},
# chat default mode
AppMode.CHAT: {
'app': {
'mode': AppMode.CHAT.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.CHAT.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
}
}
"completion_params": {},
},
},
},
# advanced-chat default mode
AppMode.ADVANCED_CHAT: {
'app': {
'mode': AppMode.ADVANCED_CHAT.value,
'enable_site': True,
'enable_api': True
}
"app": {
"mode": AppMode.ADVANCED_CHAT.value,
"enable_site": True,
"enable_api": True,
},
},
# agent-chat default mode
AppMode.AGENT_CHAT: {
'app': {
'mode': AppMode.AGENT_CHAT.value,
'enable_site': True,
'enable_api': True
"app": {
"mode": AppMode.AGENT_CHAT.value,
"enable_site": True,
"enable_api": True,
},
'model_config': {
'model': {
"model_config": {
"model": {
"provider": "openai",
"name": "gpt-4o",
"mode": "chat",
"completion_params": {}
}
}
}
"completion_params": {},
},
},
},
}
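Each mode's template separates the required "app" payload from an optional "model_config", so a consumer can read both uniformly. A short hypothetical usage sketch against the dict defined above:

template = default_app_templates[AppMode.CHAT]
app_fields = template["app"]                 # present for every mode
model_config = template.get("model_config")  # absent for WORKFLOW and ADVANCED_CHAT
print(app_fields["mode"], model_config is not None)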

View File

@ -1,3 +1,7 @@
from contextvars import ContextVar
tenant_id: ContextVar[str] = ContextVar('tenant_id')
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
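The new workflow_variable_pool ContextVar pairs with the contextvars.copy_context() call that the advanced-chat generator hands to its worker thread (see the app_generator hunk further down): values set on the request thread are visible inside the worker only because the worker runs under the copied context. A minimal, self-contained sketch of that hand-off (names are illustrative):

import threading
from contextvars import ContextVar, copy_context

request_scoped: ContextVar[str] = ContextVar("request_scoped")

def worker() -> None:
    # Runs under the copied context, so the value set by the
    # request thread is visible here.
    print(request_scoped.get())

request_scoped.set("tenant-123")
ctx = copy_context()  # snapshot of every ContextVar at this point
thread = threading.Thread(target=lambda: ctx.run(worker))
thread.start()
thread.join()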

View File

@ -33,7 +33,7 @@ class CompletionConversationApi(Resource):
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
def get(self, app_model):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('keyword', type=str, location='args')
@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
def get(self, app_model, conversation_id):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
@ -119,7 +119,7 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def delete(self, app_model, conversation_id):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
@ -256,7 +256,7 @@ class ChatConversationDetailApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
def delete(self, app_model, conversation_id):
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)

View File

@ -555,7 +555,7 @@ class DatasetRetrievalSettingApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
@ -579,7 +579,7 @@ class DatasetRetrievalSettingMockApi(Resource):
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
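The case arms above OR enum members together with |, which in structural pattern matching means "match any of these values". A stripped-down sketch of the same dispatch shape (the enum and method names here are illustrative, not the project's real ones):

from enum import Enum

class StoreKind(str, Enum):
    QDRANT = "qdrant"
    WEAVIATE = "weaviate"
    ELASTICSEARCH = "elasticsearch"
    CHROMA = "chroma"

def retrieval_methods(kind: StoreKind) -> list[str]:
    match kind:
        case StoreKind.QDRANT | StoreKind.WEAVIATE | StoreKind.ELASTICSEARCH:
            # Stores with full-text support expose every retrieval method.
            return ["semantic_search", "full_text_search", "hybrid_search"]
        case _:
            return ["semantic_search"]

assert "hybrid_search" in retrieval_methods(StoreKind.ELASTICSEARCH)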

View File

@ -178,11 +178,20 @@ class DatasetDocumentListApi(Resource):
.subquery()
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
.order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)))
.order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
elif sort == 'created_at':
query = query.order_by(sort_logic(Document.created_at))
query = query.order_by(
sort_logic(Document.created_at),
sort_logic(Document.position),
)
else:
query = query.order_by(desc(Document.created_at))
query = query.order_by(
desc(Document.created_at),
desc(Document.position),
)
paginated_documents = query.paginate(
page=page, per_page=limit, max_per_page=100, error_out=False)

View File

@ -93,6 +93,7 @@ class DatasetConfigManager:
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs["reranking_mode"],
)
)

View File

@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
@ -121,7 +121,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
stream=stream
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
@ -141,10 +141,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
@ -191,7 +191,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
:param workflow: Workflow
:param user: account or end user
:param invoke_from: invoke from source
@ -232,8 +232,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'user': user,
'context': contextvars.copy_context()
'context': contextvars.copy_context(),
})
worker_thread.start()
@ -246,7 +245,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)
return AdvancedChatAppGenerateResponseConverter.convert(
@ -259,7 +258,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user: Account,
context: contextvars.Context) -> None:
"""
Generate worker in a new thread.
@ -307,14 +305,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False) \
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def _handle_advanced_chat_response(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@ -334,7 +335,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)
try:

View File

@ -3,9 +3,6 @@ import os
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -94,7 +91,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
app_record=app_record,

View File

@ -44,7 +44,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
@ -69,14 +69,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow_system_variables: dict[SystemVariable, Any]
def __init__(
self,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
SystemVariable.USER_ID: user_id,
}
self._task_state = WorkflowTaskState()
@ -127,7 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
@ -239,7 +239,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
for queue_message in self._queue_manager.listen():
event = queue_message.event
@ -270,9 +270,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
workflow_run=workflow_run,
event=event
)
@ -307,7 +307,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -316,7 +316,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -325,7 +325,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -334,10 +334,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@ -360,10 +360,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@ -400,7 +400,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
@ -413,7 +413,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent):
@ -421,7 +421,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.close()
elif isinstance(event, QueueTextChunkEvent):
@ -446,7 +446,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
@ -458,11 +458,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response()
else:
continue
# publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
@ -507,7 +507,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']

View File

@ -1,6 +1,6 @@
import time
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileVar],
files: list["FileVar"],
query: Optional[str] = None) -> int:
"""
Get pre calculate rest tokens
@ -126,7 +128,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileVar],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
@ -366,7 +368,7 @@ class AppRunner:
message_id=message_id,
trace_manager=app_generate_entity.trace_manager
)
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
@ -418,7 +420,7 @@ class AppRunner:
inputs=inputs,
query=query
)
def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,
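The change above (and the matching one in simple_prompt_transform.py further down) moves the FileVar import behind typing.TYPE_CHECKING and quotes the annotations, a standard way to break a circular import between modules that only need each other for type hints. A minimal sketch of the pattern, assuming the same module path:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers, never at runtime, so the
        # runtime import cycle between the two modules is broken.
        from core.file.file_obj import FileVar


    def count_files(files: list["FileVar"]) -> int:
        # "FileVar" is a string forward reference; it is not resolved at
        # runtime, so the name need not exist in this namespace.
        return len(files)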

View File

@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction
def _get_conversation(self, conversation_id: str) -> Conversation:
def _get_conversation(self, conversation_id: str):
"""
Get conversation by conversation id
:param conversation_id: conversation id
@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
.first()
)
if not conversation:
raise ConversationNotExistsError()
return conversation
def _get_message(self, message_id: str) -> Message:

View File

@ -11,7 +11,8 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.enums import SystemVariable
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db

View File

@ -41,7 +41,9 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -179,7 +181,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
@ -246,7 +248,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
raise Exception('Workflow run not initialized.')
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
workflow_run=workflow_run,
event=event
)
@ -281,7 +283,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -290,7 +292,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -299,7 +301,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -308,10 +310,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@ -332,10 +334,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,

View File

@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
single_iteration_run: Optional[SingleIterationRunEntity] = None

View File

@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -13,11 +12,9 @@ from .segments import (
from .types import SegmentType
from .variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
@ -32,7 +29,6 @@ __all__ = [
'FloatVariable',
'ObjectVariable',
'SecretVariable',
'FileVariable',
'StringVariable',
'ArrayAnyVariable',
'Variable',
@ -45,11 +41,9 @@ __all__ = [
'FloatSegment',
'ObjectSegment',
'ArrayAnySegment',
'FileSegment',
'StringSegment',
'ArrayStringVariable',
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArrayFileVariable',
'ArraySegment',
]

View File

@ -2,12 +2,10 @@ from collections.abc import Mapping
from typing import Any
from configs import dify_config
from core.file.file_obj import FileVar
from .exc import VariableError
from .segments import (
ArrayAnySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -17,11 +15,9 @@ from .segments import (
)
from .types import SegmentType
from .variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f'invalid number value {value}')
case SegmentType.FILE:
result = FileVariable.model_validate(mapping)
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_FILE if isinstance(value, list):
mapping = dict(mapping)
mapping['value'] = [{'value': v} for v in value]
result = ArrayFileVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
if result.size > dify_config.MAX_VARIABLE_SIZE:
@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}')
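build_variable_from_mapping above relies on match/case value patterns with if guards: the first case whose pattern and guard both hold wins, and a bare case on the same value acts as the error branch. A self-contained sketch of that dispatch style (the type labels and return values are illustrative, not Dify's):

    def classify(value_type: str, value):
        match value_type:
            case 'number' if isinstance(value, int | float):
                return 'numeric variable'
            case 'number':
                # Reached only when the guard above failed.
                raise ValueError(f'invalid number value {value}')
            case 'array[string]' if isinstance(value, list):
                return 'string array variable'
            case _:
                raise ValueError(f'not supported value type {value_type}')


    print(classify('number', 3.14))        # numeric variable
    print(classify('array[string]', []))   # string array variable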

View File

@ -5,8 +5,6 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from core.file.file_obj import FileVar
from .types import SegmentType
@ -78,14 +76,7 @@ class IntegerSegment(Segment):
value: int
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
# TODO: embed FileVar in this model.
value: FileVar
@property
def markdown(self) -> str:
return self.value.to_markdown()
class ObjectSegment(Segment):
@ -108,7 +99,13 @@ class ObjectSegment(Segment):
class ArraySegment(Segment):
@property
def markdown(self) -> str:
return '\n'.join(['- ' + item.markdown for item in self.value])
items = []
for item in self.value:
if hasattr(item, 'to_markdown'):
items.append(item.to_markdown())
else:
items.append(str(item))
return '\n'.join(items)
class ArrayAnySegment(ArraySegment):
@ -130,7 +127,3 @@ class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]
class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[FileSegment]
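With FileSegment removed, ArraySegment.markdown above no longer assumes every item can render itself; it duck-types on to_markdown and falls back to str(). A minimal sketch of that fallback, with a stand-in item class:

    class RichItem:
        def to_markdown(self) -> str:
            return '**rich item**'


    def render(values: list) -> str:
        items = []
        for item in values:
            if hasattr(item, 'to_markdown'):
                items.append(item.to_markdown())
            else:
                # Plain values (str, int, dict, ...) degrade gracefully.
                items.append(str(item))
        return '\n'.join(items)


    print(render([RichItem(), 42, 'plain text']))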

View File

@ -10,8 +10,6 @@ class SegmentType(str, Enum):
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
ARRAY_FILE = 'array[file]'
OBJECT = 'object'
FILE = 'file'
GROUP = 'group'

View File

@ -4,11 +4,9 @@ from core.helper import encrypter
from .segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
pass
class FileVariable(FileSegment, Variable):
pass
class ObjectVariable(ObjectSegment, Variable):
pass
@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass
class ArrayFileVariable(ArrayFileSegment, Variable):
pass
class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET

View File

@ -99,7 +99,7 @@ class MessageFileParser:
# return all file objs
return new_files
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
"""
transform message files
@ -144,7 +144,7 @@ class MessageFileParser:
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
"""
transform file to file obj

View File

@ -1,4 +1,3 @@
from core.model_runtime.entities.model_entities import DefaultParameterName
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
@ -94,5 +93,16 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
},
'required': False,
'options': ['JSON', 'XML'],
}
}
},
DefaultParameterName.JSON_SCHEMA: {
'label': {
'en_US': 'JSON Schema',
},
'type': 'text',
'help': {
'en_US': 'Setting a response JSON schema will ensure the LLM adheres to it.',
'zh_Hans': '设置返回的json schema，llm将按照它返回',
},
'required': False,
},
}

View File

@ -95,6 +95,7 @@ class DefaultParameterName(Enum):
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
JSON_SCHEMA = "json_schema"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
@ -118,6 +119,7 @@ class ParameterType(Enum):
INT = "int"
STRING = "string"
BOOLEAN = "boolean"
TEXT = "text"
class ModelPropertyKey(Enum):

View File

@ -2,6 +2,7 @@
- gpt-4o
- gpt-4o-2024-05-13
- gpt-4o-2024-08-06
- chatgpt-4o-latest
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- gpt-4-turbo

View File

@ -0,0 +1,44 @@
model: chatgpt-4o-latest
label:
zh_Hans: chatgpt-4o-latest
en_US: chatgpt-4o-latest
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '2.50'
output: '10.00'
unit: '0.000001'
currency: USD

View File

@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '2.50'
output: '10.00'

View File

@ -37,6 +37,9 @@ parameter_rules:
options:
- text
- json_object
- json_schema
- name: json_schema
use_template: json_schema
pricing:
input: '0.15'
output: '0.60'

View File

@ -1,3 +1,4 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
@ -544,13 +545,18 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_object":
response_format = {"type": "json_object"}
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
response_format = {"type": "text"}
model_parameters["response_format"] = response_format
model_parameters["response_format"] = {"type": response_format}
extra_model_kwargs = {}
@ -922,11 +928,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model.startswith('ft:'):
model = model.split(':')[1]
# For now, chatgpt-4o-latest tokens can be counted with the gpt-4o tokenizer.
if model == "chatgpt-4o-latest":
model = "gpt-4o"
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
@ -946,7 +955,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"See https://platform.openai.com/docs/advanced-usage/managing-tokens for "
"information on how messages are converted to tokens."
)
num_tokens = 0
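The first hunk above turns the json_schema string parameter into OpenAI's structured-output payload, and the second maps chatgpt-4o-latest onto the gpt-4o tiktoken encoding since tiktoken has no entry for the alias. A hedged sketch of the parameter transformation; the example schema contents are hypothetical, only the response_format and json_schema keys come from the diff:

    import json

    # What a caller might send:
    model_parameters = {
        "response_format": "json_schema",
        "json_schema": json.dumps({
            "name": "person",
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            },
        }),
    }

    # The transformation the diff applies before invoking the SDK:
    schema = json.loads(model_parameters.pop("json_schema"))
    model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
    print(model_parameters["response_format"]["type"])  # json_schema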

View File

@ -0,0 +1,61 @@
model: Llama3-Chinese_v2
label:
en_US: Llama3-Chinese_v2
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Meta-Llama-3-70B-Instruct-GPTQ-Int4
label:
en_US: Meta-Llama-3-70B-Instruct-GPTQ-Int4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 1024
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Meta-Llama-3-8B-Instruct
label:
en_US: Meta-Llama-3-8B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
label:
en_US: Meta-Llama-3.1-405B-Instruct-AWQ-INT4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 410960
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Meta-Llama-3.1-8B-Instruct
label:
en_US: Meta-Llama-3.1-8B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.1
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -55,7 +55,8 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB
deprecated: true

View File

@ -55,7 +55,8 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB
deprecated: true

View File

@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
context_size: 2048
parameter_rules:
- name: temperature
use_template: temperature
@ -55,7 +55,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: completion
context_size: 8192
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
@ -55,7 +55,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -8,12 +8,12 @@ features:
- stream-tool-call
model_properties:
mode: chat
context_size: 8192
context_size: 2048
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
default: 0.7
min: 0.0
max: 2.0
help:
@ -57,7 +57,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: Qwen2-72B-Instruct
label:
en_US: Qwen2-72B-Instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -8,7 +8,7 @@ features:
- stream-tool-call
model_properties:
mode: completion
context_size: 8192
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
@ -57,7 +57,7 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: '0.000'
output: '0.000'
unit: '0.000'
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -1,6 +1,15 @@
- Meta-Llama-3.1-405B-Instruct-AWQ-INT4
- Meta-Llama-3.1-8B-Instruct
- Meta-Llama-3-70B-Instruct-GPTQ-Int4
- Meta-Llama-3-8B-Instruct
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-72B-Instruct
- Qwen2-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- Qwen-14B-Chat-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen-14B-Chat-Int4
- Qwen1.5-110B-Chat-GPTQ-Int4
- deepseek-v2-chat
- deepseek-v2-lite-chat
- Llama3-Chinese_v2
- chatglm3-6b

View File

@ -0,0 +1,61 @@
model: chatglm3-6b
label:
en_US: chatglm3-6b
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: deepseek-v2-chat
label:
en_US: deepseek-v2-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,61 @@
model: deepseek-v2-lite-chat
label:
en_US: deepseek-v2-lite-chat
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 2048
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.5
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 600
min: 1
max: 1248
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
input: "0.000"
output: "0.000"
unit: "0.000"
currency: RMB

View File

@ -0,0 +1,4 @@
model: BAAI/bge-large-en-v1.5
model_type: text-embedding
model_properties:
context_size: 32768

View File

@ -0,0 +1,4 @@
model: BAAI/bge-large-zh-v1.5
model_type: text-embedding
model_properties:
context_size: 32768

View File

@ -0,0 +1,81 @@
model: farui-plus
label:
en_US: farui-plus
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 12288
parameter_rules:
- name: temperature
use_template: temperature
type: float
default: 0.3
min: 0.0
max: 2.0
help:
zh_Hans: 用于控制随机性和多样性的程度。具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定。
en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain.
- name: max_tokens
use_template: max_tokens
type: int
default: 2000
min: 1
max: 2000
help:
zh_Hans: 用于指定模型在生成内容时token的最大数量它定义了生成的上限但不保证每次都会生成到这个数量。
en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
- name: top_p
use_template: top_p
type: float
default: 0.8
min: 0.1
max: 0.9
help:
zh_Hans: 生成过程中核采样方法概率阈值例如取值为0.8时仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
- name: top_k
type: int
min: 0
max: 99
label:
zh_Hans: 取样数量
en_US: Top k
help:
zh_Hans: 生成时采样候选集的大小。例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大生成的随机性越高取值越小生成的确定性越高。
en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
- name: seed
required: false
type: int
default: 1234
label:
zh_Hans: 随机种子
en_US: Random seed
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used during generation, which controls the randomness of the content generated by the model. Supports unsigned 64-bit integers; the default value is 1234. When a seed is used, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: repetition_penalty
required: false
type: float
default: 1.1
label:
en_US: Repetition penalty
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
- name: enable_search
type: boolean
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
use_template: response_format
pricing:
input: '0.02'
output: '0.02'
unit: '0.001'
currency: RMB

View File

@ -159,6 +159,8 @@ You should also complete the text started with ``` but not tell ``` directly.
"""
if model in ['qwen-turbo-chat', 'qwen-plus-chat']:
model = model.replace('-chat', '')
if model == 'farui-plus':
model = 'qwen-farui-plus'
if model in self.tokenizers:
tokenizer = self.tokenizers[model]

View File

@ -1 +1 @@
- soloar-1-mini-chat
- solar-1-mini-chat

View File

@ -1,11 +1,10 @@
import enum
import json
import os
from typing import Optional
from typing import TYPE_CHECKING, Optional
from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
PromptMessage,
@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class ModelMode(enum.Enum):
COMPLETION = 'completion'
@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: list[FileVar],
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> \
@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict,
query: str,
context: Optional[str],
files: list[FileVar],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict,
query: str,
context: Optional[str],
files: list[FileVar],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform):
return [self.get_last_user_message(prompt, files)], stops
def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:

View File

@ -0,0 +1,191 @@
import json
from typing import Any
import requests
from elasticsearch import Elasticsearch
from flask import current_app
from pydantic import BaseModel, model_validator
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from models.dataset import Dataset
class ElasticSearchConfig(BaseModel):
host: str
port: str
username: str
password: str
@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config HOST is required")
if not values['port']:
raise ValueError("config PORT is required")
if not values['username']:
raise ValueError("config USERNAME is required")
if not values['password']:
raise ValueError("config PASSWORD is required")
return values
class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
super().__init__(index_name.lower())
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
client = Elasticsearch(
hosts=f'{config.host}:{config.port}',
basic_auth=(config.username, config.password),
request_timeout=100000,
retry_on_timeout=True,
max_retries=10000,
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
return client
def get_type(self) -> str:
return 'elasticsearch'
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
mapping = {
"properties": {
"text": {
"type": "text"
},
"vector": {
"type": "dense_vector",
"index": True,
"dims": dim,
"similarity": "l2_norm"
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mapping)
added_ids = []
for i, text in enumerate(texts):
self._client.index(index=self._collection_name,
id=uuids[i],
document={
"text": text,
"vector": embeddings[i] if embeddings[i] else None,
"metadata": metadatas[i] if metadatas[i] else {},
})
added_ids.append(uuids[i])
self._client.indices.refresh(index=self._collection_name)
return uuids
def text_exists(self, id: str) -> bool:
return bool(self._client.exists(index=self._collection_name, id=id))
def delete_by_ids(self, ids: list[str]) -> None:
for id in ids:
self._client.delete(index=self._collection_name, id=id)
def delete_by_metadata_field(self, key: str, value: str) -> None:
query_str = {
'query': {
'match': {
f'metadata.{key}': f'{value}'
}
}
}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit['_id'] for hit in results['hits']['hits']]
if ids:
self.delete_by_ids(ids)
def delete(self) -> None:
self._client.indices.delete(index=self._collection_name)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_str = {
"query": {
"script_score": {
"query": {
"match_all": {}
},
"script": {
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
"params": {
"query_vector": query_vector
}
}
}
}
}
results = self._client.search(index=self._collection_name, body=query_str)
docs_and_scores = []
for hit in results['hits']['hits']:
docs_and_scores.append(
(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get('score_threshold') or 0.0
if score > score_threshold:
doc.metadata['score'] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {
"match": {
"text": query
}
}
results = self._client.search(index=self._collection_name, query=query_str)
docs = []
for hit in results['hits']['hits']:
docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
return self.add_texts(texts, embeddings, **kwargs)
class ElasticSearchVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
config = current_app.config
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get('ELASTICSEARCH_HOST'),
port=config.get('ELASTICSEARCH_PORT'),
username=config.get('ELASTICSEARCH_USERNAME'),
password=config.get('ELASTICSEARCH_PASSWORD'),
),
attributes=[]
)

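For reviewers who want to poke at the new store, here is a minimal smoke-test sketch (not part of this commit): it assumes a local Elasticsearch reachable with the credentials from .env.example; the index name, document, and two-dimensional vectors are all illustrative.

# Hypothetical smoke test for the new ElasticSearchVector; assumes a local
# Elasticsearch instance with the .env.example credentials. Not part of the diff.
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
    ElasticSearchConfig,
    ElasticSearchVector,
)
from core.rag.models.document import Document

vector = ElasticSearchVector(
    index_name="demo_index",  # illustrative index name
    config=ElasticSearchConfig(
        host="http://127.0.0.1",  # scheme included so the client accepts f'{host}:{port}'
        port="9200",
        username="elastic",
        password="elastic",
    ),
    attributes=[],
)
docs = [Document(page_content="hello world", metadata={"doc_id": "1"})]
vector.create(docs, embeddings=[[0.1, 0.2]])  # toy 2-dim vectors; real ones are model-sized
hits = vector.search_by_vector([0.1, 0.2], score_threshold=0.5)
print([d.page_content for d in hits])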
View File

@ -93,7 +93,7 @@ class MyScaleVector(BaseVector):
@staticmethod
def escape_str(value: Any) -> str:
return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))
return "".join(" " if c in ("\\", "'") else c for c in str(value))
def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
@ -118,7 +118,7 @@ class MyScaleVector(BaseVector):
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)

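Judging by the flag name, `enable_nlq=false` switches TextSearch to plain keyword scoring instead of natural-language query parsing. The `_search` body is outside this hunk, but the SQL it issues should roughly take this shape (database, table, and query values are placeholders, not code from this diff):

# Sketch of the SQL shape MyScaleVector._search plausibly produces after this change.
database, table, top_k = "dify", "embeddings_demo", 5
query = "parallel workflow support"
dist = f"TextSearch('enable_nlq=false')(text, '{query}')"
sql = f"SELECT text, metadata, {dist} AS dist FROM {database}.{table} ORDER BY dist DESC LIMIT {top_k}"
print(sql)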
View File

@ -71,6 +71,9 @@ class Vector:
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory

View File

@ -15,3 +15,4 @@ class VectorType(str, Enum):
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
ORACLE = 'oracle'
ELASTICSEARCH = 'elasticsearch'

View File

@ -46,7 +46,7 @@ class ToolProviderType(Enum):
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
class ApiProviderSchemaType(Enum):
"""
Enum class for api provider schema type.
@ -68,7 +68,7 @@ class ApiProviderSchemaType(Enum):
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
class ApiProviderAuthType(Enum):
"""
Enum class for api provider auth type.
@ -103,8 +103,8 @@ class ToolInvokeMessage(BaseModel):
"""
plain text, image url or link url
"""
message: Union[str, bytes, dict] = None
meta: dict[str, Any] = None
message: str | bytes | dict | None = None
meta: dict[str, Any] | None = None
save_as: str = ''
class ToolInvokeMessageBinary(BaseModel):
@ -154,8 +154,8 @@ class ToolParameter(BaseModel):
options: Optional[list[ToolParameterOption]] = None
@classmethod
def get_simple_instance(cls,
name: str, llm_description: str, type: ToolParameterType,
required: bool, options: Optional[list[str]] = None) -> 'ToolParameter':
"""
get a simple tool parameter
@ -222,7 +222,7 @@ class ToolProviderCredentials(BaseModel):
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
@staticmethod
def default(value: str) -> str:
return ""
@ -290,7 +290,7 @@ class ToolRuntimeVariablePool(BaseModel):
'tenant_id': self.tenant_id,
'pool': [variable.model_dump() for variable in self.pool],
}
def set_text(self, tool_name: str, name: str, value: str) -> None:
"""
set a text variable
@ -301,7 +301,7 @@ class ToolRuntimeVariablePool(BaseModel):
variable = cast(ToolRuntimeTextVariable, variable)
variable.value = value
return
variable = ToolRuntimeTextVariable(
type=ToolRuntimeVariableType.TEXT,
name=name,
@ -334,7 +334,7 @@ class ToolRuntimeVariablePool(BaseModel):
variable = cast(ToolRuntimeImageVariable, variable)
variable.value = value
return
variable = ToolRuntimeImageVariable(
type=ToolRuntimeVariableType.IMAGE,
name=name,
@ -388,21 +388,21 @@ class ToolInvokeMeta(BaseModel):
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
"""
Get an instance of ToolInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
'time_cost': self.time_cost,
'error': self.error,
'tool_config': self.tool_config,
}
class ToolLabel(BaseModel):
"""
Tool label
@ -416,4 +416,4 @@ class ToolInvokeFrom(Enum):
Enum class for tool invoke
"""
WORKFLOW = "workflow"
AGENT = "agent"

View File

@ -0,0 +1,2 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="24" height="25" viewBox="0 0 24 25" xmlns="http://www.w3.org/2000/svg" fill="none"><path fill="#FC6D26" d="M14.975 8.904L14.19 6.55l-1.552-4.67a.268.268 0 00-.255-.18.268.268 0 00-.254.18l-1.552 4.667H5.422L3.87 1.879a.267.267 0 00-.254-.179.267.267 0 00-.254.18l-1.55 4.667-.784 2.357a.515.515 0 00.193.583l6.78 4.812 6.778-4.812a.516.516 0 00.196-.583z"/><path fill="#E24329" d="M8 14.296l2.578-7.75H5.423L8 14.296z"/><path fill="#FC6D26" d="M8 14.296l-2.579-7.75H1.813L8 14.296z"/><path fill="#FCA326" d="M1.81 6.549l-.784 2.354a.515.515 0 00.193.583L8 14.3 1.81 6.55z"/><path fill="#E24329" d="M1.812 6.549h3.612L3.87 1.882a.268.268 0 00-.254-.18.268.268 0 00-.255.18L1.812 6.549z"/><path fill="#FC6D26" d="M8 14.296l2.578-7.75h3.614L8 14.296z"/><path fill="#FCA326" d="M14.19 6.549l.783 2.354a.514.514 0 01-.193.583L8 14.296l6.188-7.747h.001z"/><path fill="#E24329" d="M14.19 6.549H10.58l1.551-4.667a.267.267 0 01.255-.18c.115 0 .217.073.254.18l1.552 4.667z"/></svg>


View File

@ -0,0 +1,34 @@
from typing import Any
import requests
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class GitlabProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
raise ToolProviderCredentialValidationError("Gitlab Access Token is required.")
if 'site_url' not in credentials or not credentials.get('site_url'):
site_url = 'https://gitlab.com'
else:
site_url = credentials.get('site_url')
try:
headers = {
"Content-Type": "application/vnd.text+json",
"Authorization": f"Bearer {credentials.get('access_tokens')}",
}
response = requests.get(
url=f"{site_url}/api/v4/user",
headers=headers)
if response.status_code != 200:
raise ToolProviderCredentialValidationError((response.json()).get('message'))
except Exception as e:
raise ToolProviderCredentialValidationError("Gitlab Access Token or API version is invalid: {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

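The credential check reduces to one authenticated GET against the GitLab user endpoint; a standalone sketch (token and URL are placeholders):

import requests

# Placeholder credentials; substitute a real personal access token and, for
# self-hosted installs, the instance URL.
site_url = "https://gitlab.com"
access_token = "glpat-xxxxxxxxxxxxxxxxxxxx"
response = requests.get(
    url=f"{site_url}/api/v4/user",
    headers={"Authorization": f"Bearer {access_token}"},
)
print(response.status_code, response.json().get("username"))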
View File

@ -0,0 +1,38 @@
identity:
author: Leo.Wang
name: gitlab
label:
en_US: Gitlab
zh_Hans: Gitlab
description:
en_US: Gitlab plugin for querying commits
zh_Hans: 用于获取Gitlab commit的插件
icon: gitlab.svg
credentials_for_provider:
access_tokens:
type: secret-input
required: true
label:
en_US: Gitlab access token
zh_Hans: Gitlab access token
placeholder:
en_US: Please input your Gitlab access token
zh_Hans: 请输入你的 Gitlab access token
help:
en_US: Get your Gitlab access token from Gitlab
zh_Hans: 从 Gitlab 获取您的 access token
url: https://docs.gitlab.com/16.9/ee/api/oauth2.html
site_url:
type: text-input
required: false
default: 'https://gitlab.com'
label:
en_US: Gitlab site url
zh_Hans: Gitlab site url
placeholder:
en_US: Please input your Gitlab site url
zh_Hans: 请输入你的 Gitlab site url
help:
en_US: Find your Gitlab url
zh_Hans: 找到你的Gitlab url
url: https://gitlab.com/help

View File

@ -0,0 +1,101 @@
import json
from datetime import datetime, timedelta
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GitlabCommitsTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get('project', '')
employee = tool_parameters.get('employee', '')
start_time = tool_parameters.get('start_time', '')
end_time = tool_parameters.get('end_time', '')
if not project:
return self.create_text_message('Project is required')
if not start_time:
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
if not end_time:
end_time = datetime.utcnow().isoformat()
access_token = self.runtime.credentials.get('access_tokens')
site_url = self.runtime.credentials.get('site_url')
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
return self.create_text_message("Gitlab API Access Token is required.")
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
site_url = 'https://gitlab.com'
# Get commit content
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time)
return self.create_text_message(json.dumps(result, ensure_ascii=False))
def fetch(self, user_id: str, site_url: str, access_token: str, project: str, employee: str | None = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
# Get all projects
url = f"{domain}/api/v4/projects"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()
filtered_projects = [p for p in projects if project == "*" or p['name'] == project]
for project in filtered_projects:
project_id = project['id']
project_name = project['name']
print(f"Project: {project_name}")
# Get all commits of the project
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
params = {
'since': start_time,
'until': end_time
}
if employee:
params['author'] = employee
commits_response = requests.get(commits_url, headers=headers, params=params)
commits_response.raise_for_status()
commits = commits_response.json()
for commit in commits:
commit_sha = commit['id']
print(f"\tCommit SHA: {commit_sha}")
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
diff_response = requests.get(diff_url, headers=headers)
diff_response.raise_for_status()
diffs = diff_response.json()
for diff in diffs:
# Calculate the number of changed lines
added_lines = diff['diff'].count('\n+')
removed_lines = diff['diff'].count('\n-')
total_changes = added_lines + removed_lines
if total_changes > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
results.append({
"project": project_name,
"commit_sha": commit_sha,
"diff": final_code
})
print(f"Commit code: {final_code}")
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results

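The added/removed counts lean on unified-diff conventions: added lines start with '+', removed lines with '-', so counting '\n+' and '\n-' in the raw text approximates the change size; header lines like '+++ b/...' also match, which is why the join above filters them out. A self-contained illustration with a toy diff:

# Toy unified diff to illustrate the counting used in GitlabCommitsTool.fetch;
# the text below is made up, not a real GitLab response.
diff_text = (
    "--- a/app.py\n"
    "+++ b/app.py\n"
    "@@ -1,2 +1,3 @@\n"
    "-print('old')\n"
    "+print('new')\n"
    "+print('extra')\n"
)
added_lines = diff_text.count("\n+")    # 3, not 2: the '+++ b/app.py' header also matches
removed_lines = diff_text.count("\n-")  # 1
final_code = "".join(
    line[1:] for line in diff_text.split("\n") if line.startswith("+") and not line.startswith("+++")
)
print(added_lines, removed_lines, repr(final_code))  # 3 1 "print('new')print('extra')"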
View File

@ -0,0 +1,56 @@
identity:
name: gitlab_commits
author: Leo.Wang
label:
en_US: Gitlab Commits
zh_Hans: Gitlab代码提交内容
description:
human:
en_US: A tool for querying Gitlab commits. Input should be an existing username.
zh_Hans: 一个用于查询gitlab代码提交记录的工具，输入的内容应该是一个已存在的用户名或者项目名。
llm: A tool for querying Gitlab commits. Input should be an existing username or project.
parameters:
- name: employee
type: string
required: false
label:
en_US: employee
zh_Hans: 员工用户名
human_description:
en_US: employee
zh_Hans: 员工用户名
llm_description: employee for gitlab
form: llm
- name: project
type: string
required: true
label:
en_US: project
zh_Hans: 项目名
human_description:
en_US: project
zh_Hans: 项目名
llm_description: project for gitlab
form: llm
- name: start_time
type: string
required: false
label:
en_US: start_time
zh_Hans: 开始时间
human_description:
en_US: start_time
zh_Hans: 开始时间
llm_description: start_time for gitlab
form: llm
- name: end_time
type: string
required: false
label:
en_US: end_time
zh_Hans: 结束时间
human_description:
en_US: end_time
zh_Hans: 结束时间
llm_description: end_time for gitlab
form: llm

View File

@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import deepcopy
from enum import Enum
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileVar
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_file_manager import ToolFileManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class Tool(BaseModel, ABC):
identity: Optional[ToolIdentity] = None
@ -76,7 +78,7 @@ class Tool(BaseModel, ABC):
description=self.description.model_copy() if self.description else None,
runtime=Tool.Runtime(**runtime),
)
@abstractmethod
def tool_provider_type(self) -> ToolProviderType:
"""
@ -84,7 +86,7 @@ class Tool(BaseModel, ABC):
:return: the tool provider type
"""
def load_variables(self, variables: ToolRuntimeVariablePool):
"""
load variables from database
@ -99,7 +101,7 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return
self.variables.set_file(self.identity.name, variable_name, image_key)
def set_text_variable(self, variable_name: str, text: str) -> None:
@ -108,9 +110,9 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return
self.variables.set_text(self.identity.name, variable_name, text)
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
"""
get a variable
@ -120,14 +122,14 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return None
if isinstance(name, Enum):
name = name.value
for variable in self.variables.pool:
if variable.name == name:
return variable
return None
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
@ -138,9 +140,9 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return None
return self.get_variable(self.VARIABLE_KEY.IMAGE)
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
"""
get a variable file
@ -151,7 +153,7 @@ class Tool(BaseModel, ABC):
variable = self.get_variable(name)
if not variable:
return None
if not isinstance(variable, ToolRuntimeImageVariable):
return None
@ -160,9 +162,9 @@ class Tool(BaseModel, ABC):
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
if not file_binary:
return None
return file_binary[0]
def list_variables(self) -> list[ToolRuntimeVariable]:
"""
list all variables
@ -171,9 +173,9 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return []
return self.variables.pool
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
"""
list all image variables
@ -182,9 +184,9 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return []
result = []
for variable in self.variables.pool:
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
result.append(variable)
@ -225,7 +227,7 @@ class Tool(BaseModel, ABC):
@abstractmethod
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
pass
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
"""
validate the credentials
@ -244,7 +246,7 @@ class Tool(BaseModel, ABC):
:return: the runtime parameters
"""
return self.parameters or []
def get_all_runtime_parameters(self) -> list[ToolParameter]:
"""
get all runtime parameters
@ -278,7 +280,7 @@ class Tool(BaseModel, ABC):
parameters.append(parameter)
return parameters
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
"""
create an image message
@ -286,18 +288,18 @@ class Tool(BaseModel, ABC):
:param image: the url of the image
:return: the image message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
message=image,
save_as=save_as)
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
message='',
meta={
'file_var': file_var
},
save_as='')
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message
@ -305,10 +307,10 @@ class Tool(BaseModel, ABC):
:param link: the url of the link
:return: the link message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
message=link,
save_as=save_as)
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a text message
@ -321,7 +323,7 @@ class Tool(BaseModel, ABC):
message=text,
save_as=save_as
)
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
create a blob message

View File

@ -1,7 +1,7 @@
import logging
from mimetypes import guess_extension
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.file.file_obj import FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
@ -28,12 +28,12 @@ class ToolFileMessageTransformer:
# try to download image
try:
file = ToolFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_url=message.message
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
result.append(ToolInvokeMessage(
@ -56,14 +56,14 @@ class ToolFileMessageTransformer:
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode('utf-8')
file = ToolFileManager.create_file_by_raw(
user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message,
mimetype=mimetype
)
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
# check if file is image
@ -82,7 +82,7 @@ class ToolFileMessageTransformer:
meta=message.meta.copy() if message.meta is not None else {},
))
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var: FileVar = message.meta.get('file_var')
file_var = message.meta.get('file_var')
if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
@ -104,7 +104,7 @@ class ToolFileMessageTransformer:
result.append(message)
return result
@classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
return f'/files/tools/{tool_file_id}{extension or ".bin"}'

View File

@ -4,13 +4,14 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus
class NodeType(Enum):
"""
Node Types.
"""
START = 'start'
END = 'end'
ANSWER = 'answer'
@ -44,33 +45,11 @@ class NodeType(Enum):
raise ValueError(f'invalid node type value {value}')
class SystemVariable(Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
@classmethod
def value_of(cls, value: str) -> 'SystemVariable':
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
class NodeRunMetadataKey(Enum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = 'total_tokens'
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
"""
Node Run Result.
"""
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict[str, Any]] = None # node inputs

View File

@ -7,7 +7,7 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable
VariableValue = Union[str, int, float, dict, list, FileVar]

View File

@ -0,0 +1,25 @@
from enum import Enum
class SystemVariable(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')

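Because the relocated enum now mixes in `str`, members compare equal to their raw values, which is what lets variable-pool keys such as `['sys', SystemVariable.FILES.value]` interoperate with plain strings. A quick illustration with a local copy of the class:

# Local copy of two members for illustration; the real enum lives in
# core.workflow.enums and also defines FILES, USER_ID and DIALOGUE_COUNT.
from enum import Enum

class SystemVariable(str, Enum):
    QUERY = 'query'
    CONVERSATION_ID = 'conversation_id'

assert SystemVariable.QUERY == 'query'              # str mixin: equal to the raw value
assert SystemVariable('query') is SystemVariable.QUERY
print([member.value for member in SystemVariable])  # ['query', 'conversation_id']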
View File

@ -140,9 +140,6 @@ class HttpRequestNode(BaseNode):
"""
files = []
mimetype, file_binary = response.extract_file()
# if not image, return directly
if 'image' not in mimetype:
return files
if mimetype:
# extract filename from url

View File

@ -1,7 +1,7 @@
import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from pydantic import BaseModel
@ -9,7 +9,6 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@ -24,9 +23,10 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
@ -41,6 +41,10 @@ from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class ModelInvokeCompleted(BaseModel):
"""
@ -81,7 +85,7 @@ class LLMNode(BaseNode):
node_inputs = {}
# fetch files
files: list[FileVar] = self._fetch_files(node_data, variable_pool)
files = self._fetch_files(node_data, variable_pool)
if files:
node_inputs['#files#'] = [file.to_dict() for file in files]
@ -368,7 +372,7 @@ class LLMNode(BaseNode):
return inputs # type: ignore
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
"""
Fetch files
:param node_data: node data
@ -563,7 +567,7 @@ class LLMNode(BaseNode):
query: Optional[str],
query_prompt_template: Optional[str],
inputs: dict[str, str],
files: list[FileVar],
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
@ -678,8 +682,8 @@ class LLMNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData
) -> Mapping[str, Sequence[str]]:

View File

@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast
from core.app.segments import parser
from core.app.segments import ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
@ -140,9 +141,9 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
# FIXME: ensure this is an ArrayVariable that contains FileVariable items.
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
return [file_var.value for file_var in variable.value] if variable else []
assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
-> tuple[str, list[FileVar], list[dict]]:

View File

@ -1,13 +1,13 @@
from blinker import signal
# sender: app
app_was_created = signal('app-was-created')
app_was_created = signal("app-was-created")
# sender: app, kwargs: app_model_config
app_model_config_was_updated = signal('app-model-config-was-updated')
app_model_config_was_updated = signal("app-model-config-was-updated")
# sender: app, kwargs: published_workflow
app_published_workflow_was_updated = signal('app-published-workflow-was-updated')
app_published_workflow_was_updated = signal("app-published-workflow-was-updated")
# sender: app, kwargs: synced_draft_workflow
app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced')
app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")

View File

@ -1,4 +1,4 @@
from blinker import signal
# sender: dataset
dataset_was_deleted = signal('dataset-was-deleted')
dataset_was_deleted = signal("dataset-was-deleted")

View File

@ -1,4 +1,4 @@
from blinker import signal
# sender: document
document_was_deleted = signal('document-was-deleted')
document_was_deleted = signal("document-was-deleted")

View File

@ -5,5 +5,11 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender, **kwargs):
dataset = sender
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
dataset.index_struct, dataset.collection_binding_id, dataset.doc_form)
clean_dataset_task.delay(
dataset.id,
dataset.tenant_id,
dataset.indexing_technique,
dataset.index_struct,
dataset.collection_binding_id,
dataset.doc_form,
)

View File

@ -5,7 +5,7 @@ from tasks.clean_document_task import clean_document_task
@document_was_deleted.connect
def handle(sender, **kwargs):
document_id = sender
dataset_id = kwargs.get('dataset_id')
doc_form = kwargs.get('doc_form')
file_id = kwargs.get('file_id')
dataset_id = kwargs.get("dataset_id")
doc_form = kwargs.get("doc_form")
file_id = kwargs.get("file_id")
clean_document_task.delay(document_id, dataset_id, doc_form, file_id)

View File

@ -14,21 +14,25 @@ from models.dataset import Document
@document_index_created.connect
def handle(sender, **kwargs):
dataset_id = sender
document_ids = kwargs.get('document_ids', None)
document_ids = kwargs.get("document_ids", None)
documents = []
start_at = time.perf_counter()
for document_id in document_ids:
logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
document = (
db.session.query(Document)
.filter(
Document.id == document_id,
Document.dataset_id == dataset_id,
)
.first()
)
if not document:
raise NotFound('Document not found')
raise NotFound("Document not found")
document.indexing_status = 'parsing'
document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
documents.append(document)
db.session.add(document)
@ -38,8 +42,8 @@ def handle(sender, **kwargs):
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow'))
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass

View File

@ -10,7 +10,7 @@ def handle(sender, **kwargs):
installed_app = InstalledApp(
tenant_id=app.tenant_id,
app_id=app.id,
app_owner_tenant_id=app.tenant_id
app_owner_tenant_id=app.tenant_id,
)
db.session.add(installed_app)
db.session.commit()

View File

@ -7,15 +7,15 @@ from models.model import Site
def handle(sender, **kwargs):
"""Create site record when an app is created."""
app = sender
account = kwargs.get('account')
account = kwargs.get("account")
site = Site(
app_id=app.id,
title=app.name,
icon = app.icon,
icon_background = app.icon_background,
icon=app.icon,
icon_background=app.icon_background,
default_language=account.interface_language,
customize_token_strategy='not_allow',
code=Site.generate_code(16)
customize_token_strategy="not_allow",
code=Site.generate_code(16),
)
db.session.add(site)

View File

@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType
@message_was_created.connect
def handle(sender, **kwargs):
message = sender
application_generate_entity = kwargs.get('application_generate_entity')
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
@ -39,7 +39,7 @@ def handle(sender, **kwargs):
elif quota_unit == QuotaUnit.CREDITS:
used_quota = 1
if 'gpt-4' in model_config.model:
if "gpt-4" in model_config.model:
used_quota = 20
else:
used_quota = 1
@ -50,6 +50,6 @@ def handle(sender, **kwargs):
Provider.provider_name == model_config.provider,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + used_quota})
Provider.quota_limit > Provider.quota_used,
).update({"quota_used": Provider.quota_used + used_quota})
db.session.commit()

View File

@ -8,8 +8,8 @@ from events.app_event import app_draft_workflow_was_synced
@app_draft_workflow_was_synced.connect
def handle(sender, **kwargs):
app = sender
for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []):
if node_data.get('data', {}).get('type') == NodeType.TOOL.value:
for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try:
tool_entity = ToolEntity(**node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(
@ -23,7 +23,7 @@ def handle(sender, **kwargs):
tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type,
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}',
)
manager.delete_tool_parameters_cache()
except:

View File

@ -1,4 +1,4 @@
from blinker import signal
# sender: document
document_index_created = signal('document-index-created')
document_index_created = signal("document-index-created")

View File

@ -7,13 +7,11 @@ from models.model import AppModelConfig
@app_model_config_was_updated.connect
def handle(sender, **kwargs):
app = sender
app_model_config = kwargs.get('app_model_config')
app_model_config = kwargs.get("app_model_config")
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id
).all()
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids = []
if not app_dataset_joins:
@ -29,16 +27,12 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id,
AppDatasetJoin.dataset_id == dataset_id
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(
app_id=app.id,
dataset_id=dataset_id
)
app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
db.session.add(app_dataset_join)
db.session.commit()
@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
agent_mode = app_model_config.agent_mode_dict
tools = agent_mode.get('tools', []) or []
tools = agent_mode.get("tools", []) or []
for tool in tools:
if len(list(tool.keys())) != 1:
continue
@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set:
# get dataset from dataset_configs
dataset_configs = app_model_config.dataset_configs_dict
datasets = dataset_configs.get('datasets', {}) or {}
for dataset in datasets.get('datasets', []) or []:
datasets = dataset_configs.get("datasets", {}) or {}
for dataset in datasets.get("datasets", []) or []:
keys = list(dataset.keys())
if len(keys) == 1 and keys[0] == 'dataset':
if dataset['dataset'].get('id'):
dataset_ids.add(dataset['dataset'].get('id'))
if len(keys) == 1 and keys[0] == "dataset":
if dataset["dataset"].get("id"):
dataset_ids.add(dataset["dataset"].get("id"))
return dataset_ids

View File

@ -11,13 +11,11 @@ from models.workflow import Workflow
@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
app = sender
published_workflow = kwargs.get('published_workflow')
published_workflow = kwargs.get("published_workflow")
published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id
).all()
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids = []
if not app_dataset_joins:
@ -33,16 +31,12 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id,
AppDatasetJoin.dataset_id == dataset_id
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(
app_id=app.id,
dataset_id=dataset_id
)
app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id)
db.session.add(app_dataset_join)
db.session.commit()
@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
if not graph:
return dataset_ids
nodes = graph.get('nodes', [])
nodes = graph.get("nodes", [])
# fetch all knowledge retrieval nodes
knowledge_retrieval_nodes = [node for node in nodes
if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value]
knowledge_retrieval_nodes = [
node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
]
if not knowledge_retrieval_nodes:
return dataset_ids
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get('data', {}))
node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
dataset_ids.update(node_data.dataset_ids)
except Exception as e:
continue

View File

@ -9,13 +9,13 @@ from models.provider import Provider
@message_was_created.connect
def handle(sender, **kwargs):
message = sender
application_generate_entity = kwargs.get('application_generate_entity')
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == application_generate_entity.model_conf.provider
).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)})
Provider.provider_name == application_generate_entity.model_conf.provider,
).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)})
db.session.commit()

View File

@ -1,4 +1,4 @@
from blinker import signal
# sender: message, kwargs: conversation
message_was_created = signal('message-was-created')
message_was_created = signal("message-was-created")

View File

@ -1,7 +1,7 @@
from blinker import signal
# sender: tenant
tenant_was_created = signal('tenant-was-created')
tenant_was_created = signal("tenant-was-created")
# sender: tenant
tenant_was_updated = signal('tenant-was-updated')
tenant_was_updated = signal("tenant-was-updated")

View File

@ -17,7 +17,7 @@ def init_app(app: Flask) -> Celery:
backend=app.config["CELERY_BACKEND"],
task_ignore_result=True,
)
# Add SSL options to the Celery configuration
ssl_options = {
"ssl_cert_reqs": None,
@ -35,7 +35,7 @@ def init_app(app: Flask) -> Celery:
celery_app.conf.update(
broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration
)
celery_app.set_default()
app.extensions["celery"] = celery_app
@ -45,18 +45,15 @@ def init_app(app: Flask) -> Celery:
]
day = app.config["CELERY_BEAT_SCHEDULER_TIME"]
beat_schedule = {
'clean_embedding_cache_task': {
'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task',
'schedule': timedelta(days=day),
"clean_embedding_cache_task": {
"task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
"schedule": timedelta(days=day),
},
"clean_unused_datasets_task": {
"task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
"schedule": timedelta(days=day),
},
'clean_unused_datasets_task': {
'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
'schedule': timedelta(days=day),
}
}
celery_app.conf.update(
beat_schedule=beat_schedule,
imports=imports
)
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@ -2,15 +2,14 @@ from flask import Flask
def init_app(app: Flask):
if app.config.get('API_COMPRESSION_ENABLED'):
if app.config.get("API_COMPRESSION_ENABLED"):
from flask_compress import Compress
app.config['COMPRESS_MIMETYPES'] = [
'application/json',
'image/svg+xml',
'text/html',
app.config["COMPRESS_MIMETYPES"] = [
"application/json",
"image/svg+xml",
"text/html",
]
compress = Compress()
compress.init_app(app)

View File

@ -2,11 +2,11 @@ from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
POSTGRES_INDEXES_NAMING_CONVENTION = {
'ix': '%(column_0_label)s_idx',
'uq': '%(table_name)s_%(column_0_name)s_key',
'ck': '%(table_name)s_%(constraint_name)s_check',
'fk': '%(table_name)s_%(column_0_name)s_fkey',
'pk': '%(table_name)s_pkey',
"ix": "%(column_0_label)s_idx",
"uq": "%(table_name)s_%(column_0_name)s_key",
"ck": "%(table_name)s_%(constraint_name)s_check",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)

View File

@ -14,67 +14,69 @@ class Mail:
return self._client is not None
def init_app(self, app: Flask):
if app.config.get('MAIL_TYPE'):
if app.config.get('MAIL_DEFAULT_SEND_FROM'):
self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM')
if app.config.get('MAIL_TYPE') == 'resend':
api_key = app.config.get('RESEND_API_KEY')
if not api_key:
raise ValueError('RESEND_API_KEY is not set')
if app.config.get("MAIL_TYPE"):
if app.config.get("MAIL_DEFAULT_SEND_FROM"):
self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM")
api_url = app.config.get('RESEND_API_URL')
if app.config.get("MAIL_TYPE") == "resend":
api_key = app.config.get("RESEND_API_KEY")
if not api_key:
raise ValueError("RESEND_API_KEY is not set")
api_url = app.config.get("RESEND_API_URL")
if api_url:
resend.api_url = api_url
resend.api_key = api_key
self._client = resend.Emails
elif app.config.get('MAIL_TYPE') == 'smtp':
elif app.config.get("MAIL_TYPE") == "smtp":
from libs.smtp import SMTPClient
if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type')
if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'):
raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS')
if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"):
raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type")
if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"):
raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS")
self._client = SMTPClient(
server=app.config.get('SMTP_SERVER'),
port=app.config.get('SMTP_PORT'),
username=app.config.get('SMTP_USERNAME'),
password=app.config.get('SMTP_PASSWORD'),
_from=app.config.get('MAIL_DEFAULT_SEND_FROM'),
use_tls=app.config.get('SMTP_USE_TLS'),
opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS')
server=app.config.get("SMTP_SERVER"),
port=app.config.get("SMTP_PORT"),
username=app.config.get("SMTP_USERNAME"),
password=app.config.get("SMTP_PASSWORD"),
_from=app.config.get("MAIL_DEFAULT_SEND_FROM"),
use_tls=app.config.get("SMTP_USE_TLS"),
opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"),
)
else:
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE")))
else:
logging.warning('MAIL_TYPE is not set')
logging.warning("MAIL_TYPE is not set")
def send(self, to: str, subject: str, html: str, from_: Optional[str] = None):
if not self._client:
raise ValueError('Mail client is not initialized')
raise ValueError("Mail client is not initialized")
if not from_ and self._default_send_from:
from_ = self._default_send_from
if not from_:
raise ValueError('mail from is not set')
raise ValueError("mail from is not set")
if not to:
raise ValueError('mail to is not set')
raise ValueError("mail to is not set")
if not subject:
raise ValueError('mail subject is not set')
raise ValueError("mail subject is not set")
if not html:
raise ValueError('mail html is not set')
raise ValueError("mail html is not set")
self._client.send({
"from": from_,
"to": to,
"subject": subject,
"html": html
})
self._client.send(
{
"from": from_,
"to": to,
"subject": subject,
"html": html,
}
)
def init_app(app: Flask):

View File

@ -6,18 +6,21 @@ redis_client = redis.Redis()
def init_app(app):
connection_class = Connection
if app.config.get('REDIS_USE_SSL'):
if app.config.get("REDIS_USE_SSL"):
connection_class = SSLConnection
redis_client.connection_pool = redis.ConnectionPool(**{
'host': app.config.get('REDIS_HOST'),
'port': app.config.get('REDIS_PORT'),
'username': app.config.get('REDIS_USERNAME'),
'password': app.config.get('REDIS_PASSWORD'),
'db': app.config.get('REDIS_DB'),
'encoding': 'utf-8',
'encoding_errors': 'strict',
'decode_responses': False
}, connection_class=connection_class)
redis_client.connection_pool = redis.ConnectionPool(
**{
"host": app.config.get("REDIS_HOST"),
"port": app.config.get("REDIS_PORT"),
"username": app.config.get("REDIS_USERNAME"),
"password": app.config.get("REDIS_PASSWORD"),
"db": app.config.get("REDIS_DB"),
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
},
connection_class=connection_class,
)
app.extensions['redis'] = redis_client
app.extensions["redis"] = redis_client

View File

@ -5,16 +5,13 @@ from werkzeug.exceptions import HTTPException
def init_app(app):
if app.config.get('SENTRY_DSN'):
if app.config.get("SENTRY_DSN"):
sentry_sdk.init(
dsn=app.config.get('SENTRY_DSN'),
integrations=[
FlaskIntegration(),
CeleryIntegration()
],
dsn=app.config.get("SENTRY_DSN"),
integrations=[FlaskIntegration(), CeleryIntegration()],
ignore_errors=[HTTPException, ValueError],
traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0),
profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0),
environment=app.config.get('DEPLOY_ENV'),
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}"
traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0),
profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0),
environment=app.config.get("DEPLOY_ENV"),
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}",
)

Some files were not shown because too many files have changed in this diff.