diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index ab585a5ae9..e7a8c98d26 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -1,5 +1,5 @@
-FROM mcr.microsoft.com/devcontainers/python:3.10
+FROM mcr.microsoft.com/devcontainers/python:3.12
 
 # [Optional] Uncomment this section to install additional OS packages.
 # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
-#     && apt-get -y install --no-install-recommends
\ No newline at end of file
+#     && apt-get -y install --no-install-recommends
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index ebc8bf74c1..339ad60ce0 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -1,7 +1,7 @@
 // For format details, see https://aka.ms/devcontainer.json. For config options, see the
 // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
 {
-  "name": "Python 3.10",
+  "name": "Python 3.12",
   "build": {
     "context": "..",
     "dockerfile": "Dockerfile"
diff --git a/.github/actions/setup-poetry/action.yml b/.github/actions/setup-poetry/action.yml
index 5feab33d1d..2e76676f37 100644
--- a/.github/actions/setup-poetry/action.yml
+++ b/.github/actions/setup-poetry/action.yml
@@ -4,7 +4,7 @@ inputs:
   python-version:
     description: Python version to use and the Poetry installed with
     required: true
-    default: '3.10'
+    default: '3.11'
   poetry-version:
     description: Poetry version to set up
     required: true
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 76e844aaad..e1c0bf33a4 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -20,7 +20,6 @@ jobs:
     strategy:
       matrix:
         python-version:
-          - "3.10"
          - "3.11"
          - "3.12"
 
diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml
index caddd23bab..a5ba51ce0e 100644
--- a/.github/workflows/vdb-tests.yml
+++ b/.github/workflows/vdb-tests.yml
@@ -20,7 +20,6 @@ jobs:
     strategy:
       matrix:
         python-version:
-          - "3.10"
          - "3.11"
          - "3.12"
 
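Note on the version bump above: the new floor is enforced at runtime by the api/app.py hunk later in this diff. A minimal, self-contained sketch of that gate (same logic; message text per the diff below):

```python
import sys

# Accept only CPython 3.11 or 3.12, matching the updated images and CI matrix above.
if not ((3, 11) <= sys.version_info[:2] < (3, 13)):
    print(
        f"Python 3.11 or 3.12 is required, "
        f"current version is {sys.version_info.major}.{sys.version_info.minor}"
    )
    raise SystemExit(1)
```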
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index da2928d189..22261804fc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,6 +1,8 @@
+# CONTRIBUTING
+
 So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.
 
-We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
+We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part. This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
 
@@ -10,14 +12,12 @@ In terms of licensing, please take a minute to read our short [License and Contr
 
 [Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
 
-### Feature requests:
+### Feature requests
 
 * If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.
 
 * If you want to pick one up from the existing issues, simply drop a comment below it saying so.
-
-
   A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.
 
   Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's rundown of the areas each our team members are working on at the moment:
 
@@ -40,7 +40,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
   | Non-core features and minor enhancements      | Low Priority    |
   | Valuable but not immediate                    | Future-Feature  |
 
-### Anything else (e.g. bug report, performance optimization, typo correction):
+### Anything else (e.g. bug report, performance optimization, typo correction)
 
 * Start coding right away.
 
@@ -52,7 +52,6 @@ In terms of licensing, please take a minute to read our short [License and Contr
   | Non-critical bugs, performance boosts         | Medium Priority |
   | Minor fixes (typos, confusing but working UI) | Low Priority    |
 
-
 ## Installing
 
 Here are the steps to set up Dify for development:
@@ -63,7 +62,7 @@ Here are the steps to set up Dify for development:
 
 Clone the forked repository from your terminal:
 
-```
+```shell
 git clone git@github.com:<github_username>/dify.git
 ```
 
@@ -71,11 +70,11 @@ git clone git@github.com:<github_username>/dify.git
 
 Dify requires the following dependencies to build, make sure they're installed on your system:
 
-- [Docker](https://www.docker.com/)
-- [Docker Compose](https://docs.docker.com/compose/install/)
-- [Node.js v18.x (LTS)](http://nodejs.org)
-- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+* [Docker](https://www.docker.com/)
+* [Docker Compose](https://docs.docker.com/compose/install/)
+* [Node.js v18.x (LTS)](http://nodejs.org)
+* [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
+* [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. Installations
 
@@ -85,7 +84,7 @@ Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) fo
 
 ### 5. Visit dify in your browser
 
-To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running. 
+To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
 
 ## Developing
 
@@ -97,9 +96,9 @@ To help you quickly navigate where your contribution fits, a brief, annotated ou
 
 ### Backend
 
-Dify’s backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login. 
+Dify’s backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
 
-```
+```text
 [api/]
 ├── constants             // Constant settings used throughout code base.
 ├── controllers           // API route definitions and request handling logic.
@@ -121,7 +120,7 @@ Dify’s backend is written in Python using [Flask](https://flask.palletsproject
 
 The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.
 
-```
+```text
 [web/]
 ├── app                   // layouts, pages, and components
 │   ├── (commonLayout)    // common layout used throughout the app
@@ -149,10 +148,10 @@ The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typ
 
 ## Submitting your PR
 
-At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests). 
+At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
 
 And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).
 
 ## Getting Help
 
-If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. 
+If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
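Once the stack is up, the setup can also be smoke-tested from a script. A hedged sketch — the web port comes from the guide above, while the API port (5001) and the /console/api/ping route are assumed local defaults, not something this guide states:

```python
import urllib.request

# Web UI default per the guide; the console API base is an assumed local default.
for url in ("http://localhost:3000", "http://localhost:5001/console/api/ping"):
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            print(url, "->", resp.status)
    except OSError as exc:  # URLError subclasses OSError
        print(url, "-> unreachable:", exc)
```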
diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md
index 310c55090a..4fa43b24ee 100644
--- a/CONTRIBUTING_CN.md
+++ b/CONTRIBUTING_CN.md
@@ -71,7 +71,7 @@ Dify 依赖以下工具和库：
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. 安装
 
diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md
index a68bdeddbc..22e30e9c03 100644
--- a/CONTRIBUTING_JA.md
+++ b/CONTRIBUTING_JA.md
@@ -74,7 +74,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. インストール
 
diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md
index a77239ff38..ad41f51aeb 100644
--- a/CONTRIBUTING_VI.md
+++ b/CONTRIBUTING_VI.md
@@ -73,7 +73,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) phiên bản 3.10.x
+- [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x
 
 ### 4. Cài đặt
 
@@ -153,4 +153,4 @@ Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giớ
 
 ## Nhận trợ giúp
 
-Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.
\ No newline at end of file
+Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.
diff --git a/api/Dockerfile b/api/Dockerfile
index 50d8ccd146..d9b0f2487e 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -1,5 +1,5 @@
 # base image
-FROM python:3.10-slim-bookworm AS base
+FROM python:3.12-slim-bookworm AS base
 
 WORKDIR /app/api
 
diff --git a/api/README.md b/api/README.md
index cfbc000bd8..461dac4759 100644
--- a/api/README.md
+++ b/api/README.md
@@ -42,7 +42,7 @@
 5. Install dependencies
 
    ```bash
-   poetry env use 3.10
+   poetry env use 3.12
    poetry install
    ```
 
@@ -81,5 +81,3 @@
    ```bash
    poetry run -C api bash dev/pytest/pytest_all_tests.sh
    ```
-
-
diff --git a/api/app.py b/api/app.py
index a667a84fd6..c1acb8bd0d 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,6 +1,11 @@
 import os
 import sys
 
+python_version = sys.version_info
+if not ((3, 11) <= python_version < (3, 13)):
+    print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
+    raise SystemExit(1)
+
 from configs import dify_config
 
 if not dify_config.DEBUG:
@@ -30,9 +35,6 @@ from models import account, dataset, model, source, task, tool, tools, web  # no
 
 # DO NOT REMOVE ABOVE
 
-if sys.version_info[:2] == (3, 10):
-    print("Warning: Python 3.10 will not be supported in the next version.")
-
 warnings.simplefilter("ignore", ResourceWarning)
 
diff --git a/api/configs/app_config.py b/api/configs/app_config.py
index 61de73c868..07ef6121cc 100644
--- a/api/configs/app_config.py
+++ b/api/configs/app_config.py
@@ -27,7 +27,6 @@ class DifyConfig(
         # read from dotenv format config file
         env_file=".env",
         env_file_encoding="utf-8",
-        frozen=True,
         # ignore extra attributes
         extra="ignore",
     )
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 9d0dd3fb23..84e476661d 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -2,6 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
+from .app.app_import import AppImportApi, AppImportConfirmApi
 from .files import FileApi, FilePreviewApi, FileSupportTypeApi
 from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
 
@@ -17,6 +18,10 @@ api.add_resource(FileSupportTypeApi, "/files/support-type")
 api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
 api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
 
+# Import App
+api.add_resource(AppImportApi, "/apps/imports")
+api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
+
 # Import other controllers
 from . import admin, apikey, extension, feature, ping, setup, version
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 4ed74e2b08..9687b59cd1 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,7 +1,10 @@
 import uuid
+from typing import cast
 
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden, abort
 
 from controllers.console import api
@@ -12,15 +15,16 @@ from controllers.console.wraps import (
     enterprise_license_required,
     setup_required,
 )
-from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import OpsTraceManager
+from extensions.ext_database import db
 from fields.app_fields import (
     app_detail_fields,
     app_detail_fields_with_site,
     app_pagination_fields,
 )
 from libs.login import login_required
-from services.app_dsl_service import AppDslService
+from models import Account, App
+from services.app_dsl_service import AppDslService, ImportMode
 from services.app_service import AppService
 
 ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@@ -93,99 +97,6 @@ class AppListApi(Resource):
         return app, 201
 
 
-class AppImportDependenciesCheckApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Check dependencies"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
-
-        leaked_dependencies = AppDslService.check_dependencies(
-            tenant_id=current_user.current_tenant_id, data=args["data"], account=current_user
-        )
-
-        return jsonable_encoder({"leaked": leaked_dependencies}), 200
-
-
-class AppImportApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Import app"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon_type", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        args = parser.parse_args()
-
-        app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
-        )
-
-        return app, 201
-
-
-class AppImportFromUrlApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Import app from url"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("url", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        args = parser.parse_args()
-
-        app = AppDslService.import_and_create_new_app_from_url(
-            tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
-        )
-
-        return app, 201
-
-
-class AppImportFromUrlDependenciesCheckApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("url", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
-
-        leaked_dependencies = AppDslService.check_dependencies_from_url(
-            tenant_id=current_user.current_tenant_id, url=args["url"], account=current_user
-        )
-
-        return jsonable_encoder({"leaked": leaked_dependencies}), 200
-
-
 class AppApi(Resource):
     @setup_required
     @login_required
@@ -263,10 +174,24 @@ class AppCopyApi(Resource):
         parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()
 
-        data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
-        app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
-        )
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
+            account = cast(Account, current_user)
+            result = import_service.import_app(
+                account=account,
+                import_mode=ImportMode.YAML_CONTENT.value,
+                yaml_content=yaml_content,
+                name=args.get("name"),
+                description=args.get("description"),
+                icon_type=args.get("icon_type"),
+                icon=args.get("icon"),
+                icon_background=args.get("icon_background"),
+            )
+            session.commit()
+
+            stmt = select(App).where(App.id == result.app.id)
+            app = session.scalar(stmt)
 
         return app, 201
 
@@ -407,10 +332,6 @@ class AppTraceApi(Resource):
 
 
 api.add_resource(AppListApi, "/apps")
-api.add_resource(AppImportDependenciesCheckApi, "/apps/import/dependencies/check")
-api.add_resource(AppImportApi, "/apps/import")
-api.add_resource(AppImportFromUrlApi, "/apps/import/url")
-api.add_resource(AppImportFromUrlDependenciesCheckApi, "/apps/import/url/dependencies/check")
 api.add_resource(AppApi, "/apps/<uuid:app_id>")
 api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
 api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
new file mode 100644
index 0000000000..244dcd75de
--- /dev/null
+++ b/api/controllers/console/app/app_import.py
@@ -0,0 +1,90 @@
+from typing import cast
+
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from controllers.console.wraps import (
+    account_initialization_required,
+    setup_required,
+)
+from extensions.ext_database import db
+from fields.app_fields import app_import_fields
+from libs.login import login_required
+from models import Account
+from services.app_dsl_service import AppDslService, ImportStatus
+
+
+class AppImportApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(app_import_fields)
+    def post(self):
+        # Check user role first
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("mode", type=str, required=True, location="json")
+        parser.add_argument("yaml_content", type=str, location="json")
+        parser.add_argument("yaml_url", type=str, location="json")
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
+        parser.add_argument("app_id", type=str, location="json")
+        args = parser.parse_args()
+
+        # Create service with session
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            # Import app
+            account = cast(Account, current_user)
+            result = import_service.import_app(
+                account=account,
+                import_mode=args["mode"],
+                yaml_content=args.get("yaml_content"),
+                yaml_url=args.get("yaml_url"),
+                name=args.get("name"),
+                description=args.get("description"),
+                icon_type=args.get("icon_type"),
+                icon=args.get("icon"),
+                icon_background=args.get("icon_background"),
+                app_id=args.get("app_id"),
+            )
+            session.commit()
+
+        # Return appropriate status code based on result
+        status = result.status
+        if status == ImportStatus.FAILED.value:
+            return result.model_dump(mode="json"), 400
+        elif status == ImportStatus.PENDING.value:
+            return result.model_dump(mode="json"), 202
+        return result.model_dump(mode="json"), 200
+
+
+class AppImportConfirmApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(app_import_fields)
+    def post(self, import_id):
+        # Check user role first
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        # Create service with session
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            # Confirm import
+            account = cast(Account, current_user)
+            result = import_service.confirm_import(import_id=import_id, account=account)
+            session.commit()
+
+        # Return appropriate status code based on result
+        if result.status == ImportStatus.FAILED.value:
+            return result.model_dump(mode="json"), 400
+        return result.model_dump(mode="json"), 200
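For reviewers, a sketch of how a client might exercise the two endpoints above. The base URL, auth header, response field names (`id`, `status`), and the `"yaml-content"` serialization of ImportMode.YAML_CONTENT are assumptions; the request payload mirrors the parser arguments:

```python
import requests

BASE = "http://localhost:5001/console/api"  # assumed local console API base
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # placeholder credential

with open("app.yml") as f:
    payload = {"mode": "yaml-content", "yaml_content": f.read()}

resp = requests.post(f"{BASE}/apps/imports", json=payload, headers=HEADERS)
body = resp.json()
print(resp.status_code, body.get("status"))

# 202 corresponds to ImportStatus.PENDING: the import waits for confirmation.
if resp.status_code == 202:
    resp = requests.post(f"{BASE}/apps/imports/{body['id']}/confirm", headers=HEADERS)
    print(resp.status_code, resp.json().get("status"))
```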
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 7b78f622b9..a25004be4d 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 import pytz
 from flask_login import current_user
@@ -314,7 +314,7 @@ def _get_conversation(app_model, conversation_id):
         raise NotFound("Conversation Not Exists.")
 
     if not conversation.read_at:
-        conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
         conversation.read_account_id = current_user.id
         db.session.commit()
 
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 2f5645852f..407f689819 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
@@ -75,7 +75,7 @@ class AppSite(Resource):
                 setattr(site, attr_name, value)
 
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return site
@@ -99,7 +99,7 @@ class AppSiteAccessTokenReset(Resource):
 
         site.code = Site.generate_code(16)
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return site
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 75354218c4..b5360e1b44 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -21,7 +21,6 @@ from libs.login import current_user, login_required
 from models import App
 from models.account import Account
 from models.model import AppMode
-from services.app_dsl_service import AppDslService
 from services.app_generate_service import AppGenerateService
 from services.errors.app import WorkflowHashNotEqualError
 from services.workflow_service import WorkflowService
@@ -130,34 +129,6 @@ class DraftWorkflowApi(Resource):
         }
 
 
-class DraftWorkflowImportApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
-    @marshal_with(workflow_fields)
-    def post(self, app_model: App):
-        """
-        Import draft workflow
-        """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
-
-        workflow = AppDslService.import_and_overwrite_workflow(
-            app_model=app_model, data=args["data"], account=current_user
-        )
-
-        return workflow
-
-
 class AdvancedChatDraftWorkflowRunApi(Resource):
     @setup_required
     @login_required
@@ -490,7 +461,6 @@ class ConvertToWorkflowApi(Resource):
 
 
 api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
-api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
 api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
 api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index be353cefac..d2aa7c903b 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -65,7 +65,7 @@ class ActivateApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = AccountStatus.ACTIVE.value
-        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 2ee8060502..2162bd5c01 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional
 
 import requests
@@ -108,7 +108,7 @@ class OAuthCallback(Resource):
 
         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
-            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         try:
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 06fb3a0a31..89a638ae54 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -88,7 +88,7 @@ class DataSourceApi(Resource):
         if action == "enable":
             if data_source_binding.disabled:
                 data_source_binding.disabled = False
-                data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
@@ -97,7 +97,7 @@ class DataSourceApi(Resource):
         if action == "disable":
             if not data_source_binding.disabled:
                 data_source_binding.disabled = True
-                data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 7105b417b7..ee0109f234 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,6 +1,6 @@
 import logging
 from argparse import ArgumentTypeError
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask import request
 from flask_login import current_user
@@ -681,7 +681,7 @@ class DocumentProcessingApi(DocumentResource):
                 raise InvalidActionError("Document not in indexing state.")
 
             document.paused_by = current_user.id
-            document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.paused_at = datetime.now(UTC).replace(tzinfo=None)
             document.is_paused = True
             db.session.commit()
 
@@ -761,7 +761,7 @@ class DocumentMetadataApi(DocumentResource):
                     document.doc_metadata[key] = value
 
         document.doc_type = doc_type
-        document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        document.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return {"result": "success", "message": "Document metadata updated."}, 200
@@ -803,7 +803,7 @@ class DocumentStatusApi(DocumentResource):
             document.enabled = True
             document.disabled_at = None
             document.disabled_by = None
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -820,9 +820,9 @@ class DocumentStatusApi(DocumentResource):
                 raise InvalidActionError("Document already disabled.")
 
             document.enabled = False
-            document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
             document.disabled_by = current_user.id
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -837,9 +837,9 @@ class DocumentStatusApi(DocumentResource):
                 raise InvalidActionError("Document already archived.")
 
             document.archived = True
-            document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.archived_at = datetime.now(UTC).replace(tzinfo=None)
             document.archived_by = current_user.id
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             if document.enabled:
@@ -856,7 +856,7 @@ class DocumentStatusApi(DocumentResource):
             document.archived = False
             document.archived_at = None
             document.archived_by = None
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 5d8d664e41..6f7ef86d2c 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -1,5 +1,5 @@
 import uuid
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 import pandas as pd
 from flask import request
@@ -188,7 +188,7 @@ class DatasetDocumentSegmentApi(Resource):
                 raise InvalidActionError("Segment is already disabled.")
 
             segment.enabled = False
-            segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
             segment.disabled_by = current_user.id
             db.session.commit()
 
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 125bc1af8c..85c43f8101 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import reqparse
@@ -46,7 +46,7 @@ class CompletionApi(InstalledAppResource):
         streaming = args["response_mode"] == "streaming"
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         try:
@@ -106,7 +106,7 @@ class ChatApi(InstalledAppResource):
 
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         try:
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index d72715a38c..b60c4e176b 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal_with, reqparse
@@ -81,7 +81,7 @@ class InstalledAppsListApi(Resource):
                 tenant_id=current_tenant_id,
                 app_owner_tenant_id=app.tenant_id,
                 is_pinned=False,
-                last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
+                last_used_at=datetime.now(UTC).replace(tzinfo=None),
             )
             db.session.add(new_installed_app)
             db.session.commit()
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 750f65168f..f704783cff 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -60,7 +60,7 @@ class AccountInitApi(Resource):
                 raise InvalidInvitationCodeError()
 
             invitation_code.status = "used"
-            invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             invitation_code.used_by_tenant_id = account.current_tenant_id
             invitation_code.used_by_account_id = account.id
 
@@ -68,7 +68,7 @@ class AccountInitApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = "active"
-        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         return {"result": "success"}
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index b935b23ed6..2128c4c53f 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from enum import Enum
 from functools import wraps
 from typing import Optional
@@ -198,7 +198,7 @@ def validate_and_get_api_token(scope=None):
     if not api_token:
         raise Unauthorized("Access token is invalid")
 
-    api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+    api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
     db.session.commit()
 
     return api_token
diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
index a22395b8e3..b9aae7904f 100644
--- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
@@ -1,3 +1,4 @@
+import uuid
 from typing import Optional
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index c72c838d06..ecb045a30a 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -12,7 +12,7 @@ from core.provider_manager import ProviderManager
 
 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -39,27 +39,23 @@ class ModelConfigConverter:
         )
 
         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
 
-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
 
-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
 
-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
         completion_params = model_config.parameters
@@ -77,7 +73,7 @@ class ModelConfigConverter:
         if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
             model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
 
-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")
 
         return ModelConfigWithCredentialsEntity(
diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
index 82a0e56ce8..fa30511f63 100644
--- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
@@ -1,4 +1,5 @@
 from core.app.app_config.entities import (
+    AdvancedChatMessageEntity,
     AdvancedChatPromptTemplateEntity,
     AdvancedCompletionPromptTemplateEntity,
     PromptTemplateEntity,
@@ -25,7 +26,9 @@ class PromptTemplateConfigManager:
             chat_prompt_messages = []
             for message in chat_prompt_config.get("prompt", []):
                 chat_prompt_messages.append(
-                    {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    AdvancedChatMessageEntity(
+                        **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    )
                 )
 
             advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 9b72452d7a..15bd353484 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, Field, field_validator
@@ -88,7 +88,7 @@ class PromptTemplateEntity(BaseModel):
     advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
 
 
-class VariableEntityType(str, Enum):
+class VariableEntityType(StrEnum):
     TEXT_INPUT = "text-input"
     SELECT = "select"
     PARAGRAPH = "paragraph"
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 4cb5e713e6..fbef787c98 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -138,7 +138,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index 67b79c0589..1df517d859 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -139,7 +139,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 527aeed6a9..5535b565cb 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -1,5 +1,5 @@
 import json
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from core.app.app_config.entities import VariableEntityType
@@ -7,7 +7,7 @@ from core.file import File, FileUploadConfig
 from factories import file_factory
 
 if TYPE_CHECKING:
-    from core.app.app_config.entities import AppConfig, VariableEntity
+    from core.app.app_config.entities import VariableEntity
 
 
 class BaseAppGenerator:
@@ -15,23 +15,23 @@ class BaseAppGenerator:
         self,
         *,
         user_inputs: Optional[Mapping[str, Any]],
-        app_config: "AppConfig",
+        variables: Sequence["VariableEntity"],
+        tenant_id: str,
     ) -> Mapping[str, Any]:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
-        variables = app_config.variables
         user_inputs = {
             var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
             for var in variables
         }
         user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
         # Convert files in inputs to File
-        entity_dictionary = {item.variable: item for item in app_config.variables}
+        entity_dictionary = {item.variable: item for item in variables}
         # Convert single file to File
         files_inputs = {
             k: file_factory.build_from_mapping(
                 mapping=v,
-                tenant_id=app_config.tenant_id,
+                tenant_id=tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
@@ -45,7 +45,7 @@ class BaseAppGenerator:
         file_list_inputs = {
             k: file_factory.build_from_mappings(
                 mappings=v,
-                tenant_id=app_config.tenant_id,
+                tenant_id=tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index 50e15ea9dc..dce59cc09e 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -142,7 +142,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index c903e6bf03..1678ebbfe2 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -123,7 +123,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id
+            ),
             query=query,
             files=file_objs,
             user_id=user.id,
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index da206f01e7..95ae798ec1 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional, Union
 
 from sqlalchemy import and_
@@ -200,7 +200,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             db.session.commit()
             db.session.refresh(conversation)
         else:
-            conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         message = Message(
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 792cec7be6..a652d205d3 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -108,7 +108,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             files=system_files,
             user_id=user.id,
             stream=stream,
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 2872390d46..1cf72ae79e 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -43,7 +43,6 @@ from core.workflow.graph_engine.entities.event import (
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.nodes import NodeType
-from core.workflow.nodes.iteration import IterationNodeData
 from core.workflow.nodes.node_mapping import node_type_classes_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
@@ -160,8 +159,6 @@ class WorkflowBasedAppRunner(AppRunner):
                 user_inputs=user_inputs,
                 variable_pool=variable_pool,
                 tenant_id=workflow.tenant_id,
-                node_type=node_type,
-                node_data=IterationNodeData(**iteration_node_config.get("data", {})),
             )
 
         return graph, variable_pool
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 69bc0d7f9e..15543638fc 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, field_validator
@@ -11,7 +11,7 @@ from core.workflow.nodes import NodeType
 from core.workflow.nodes.base import BaseNodeData
 
 
-class QueueEvent(str, Enum):
+class QueueEvent(StrEnum):
     """
     QueueEvent enum
    """
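This `str, Enum` → `StrEnum` change (repeated in several files below) is one reason for the new 3.11 floor: `StrEnum` only exists since Python 3.11. A small sketch of the behavioral difference that motivates the migration:

```python
from enum import Enum, StrEnum  # StrEnum requires Python 3.11+

class OldStyle(str, Enum):
    EVENT = "event"

class NewStyle(StrEnum):
    EVENT = "event"

# On 3.11+, str() and format() of a str/Enum mixin yield the qualified member name,
# while StrEnum keeps the plain-string semantics that string-keyed code relies on.
print(str(OldStyle.EVENT), f"{OldStyle.EVENT}")  # OldStyle.EVENT OldStyle.EVENT
print(str(NewStyle.EVENT), f"{NewStyle.EVENT}")  # event event
print(OldStyle.EVENT == "event", NewStyle.EVENT == "event")  # True True
```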
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index 46b8609277..9229cbcc0a 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -1,7 +1,7 @@
 import json
 import time
 from collections.abc import Mapping, Sequence
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
 
 from sqlalchemy.orm import Session
@@ -144,7 +144,7 @@ class WorkflowCycleManage:
         workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
-        workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
 
         db.session.commit()
         db.session.refresh(workflow_run)
@@ -191,7 +191,7 @@ class WorkflowCycleManage:
         workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
-        workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
 
         db.session.commit()
 
@@ -211,15 +211,18 @@ class WorkflowCycleManage:
         for workflow_node_execution in running_workflow_node_executions:
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error
-            workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
             workflow_node_execution.elapsed_time = (
                 workflow_node_execution.finished_at - workflow_node_execution.created_at
             ).total_seconds()
             db.session.commit()
 
-        db.session.refresh(workflow_run)
         db.session.close()
 
+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)
+
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
@@ -259,7 +262,7 @@ class WorkflowCycleManage:
                     NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
                 }
             )
-        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
 
         session.add(workflow_node_execution)
         session.commit()
@@ -282,7 +285,7 @@ class WorkflowCycleManage:
         execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
         )
-        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
 
         db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
@@ -326,7 +329,7 @@ class WorkflowCycleManage:
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
-        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
         execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@@ -654,7 +657,7 @@ class WorkflowCycleManage:
             if event.error is None
             else WorkflowNodeExecutionStatus.FAILED,
             error=None,
-            elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
+            elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
             total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
             execution_metadata=event.metadata,
             finished_at=int(time.time()),
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 534e00fdd9..f764c43098 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -246,7 +246,7 @@ class ProviderConfiguration(BaseModel):
         if provider_record:
             provider_record.encrypted_config = json.dumps(credentials)
             provider_record.is_valid = True
-            provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             provider_record = Provider()
@@ -401,7 +401,7 @@ class ProviderConfiguration(BaseModel):
         if provider_model_record:
             provider_model_record.encrypted_config = json.dumps(credentials)
             provider_model_record.is_valid = True
-            provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             provider_model_record = ProviderModel()
@@ -474,7 +474,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.enabled = True
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting()
@@ -508,7 +508,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.enabled = False
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting()
@@ -574,7 +574,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.load_balancing_enabled = True
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting()
@@ -608,7 +608,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.load_balancing_enabled = False
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting()
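The `timezone.utc` → `UTC` sweep through this file and the controllers above is behavior-preserving: `datetime.UTC` is an alias added in Python 3.11. A quick sketch of the invariant, including the naive-timestamp idiom the codebase uses:

```python
from datetime import UTC, datetime, timezone

assert UTC is timezone.utc  # UTC is just the 3.11+ alias, so nothing changes

# Timestamps are stored naive-in-UTC, hence the .replace(tzinfo=None) everywhere:
naive = datetime.now(UTC).replace(tzinfo=None)
print(naive.tzinfo)  # None
```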
diff --git a/api/core/file/enums.py b/api/core/file/enums.py
index f4153f1676..06b99d3eb0 100644
--- a/api/core/file/enums.py
+++ b/api/core/file/enums.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from enum import StrEnum
 
 
-class FileType(str, Enum):
+class FileType(StrEnum):
     IMAGE = "image"
     DOCUMENT = "document"
     AUDIO = "audio"
@@ -16,7 +16,7 @@ class FileType(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileTransferMethod(str, Enum):
+class FileTransferMethod(StrEnum):
     REMOTE_URL = "remote_url"
     LOCAL_FILE = "local_file"
     TOOL_FILE = "tool_file"
@@ -29,7 +29,7 @@ class FileTransferMethod(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileBelongsTo(str, Enum):
+class FileBelongsTo(StrEnum):
     USER = "user"
     ASSISTANT = "assistant"
 
@@ -41,7 +41,7 @@ class FileBelongsTo(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileAttribute(str, Enum):
+class FileAttribute(StrEnum):
     TYPE = "type"
     SIZE = "size"
     NAME = "name"
@@ -51,5 +51,5 @@ class FileAttribute(str, Enum):
     EXTENSION = "extension"
 
 
-class ArrayFileAttribute(str, Enum):
+class ArrayFileAttribute(StrEnum):
     LENGTH = "length"
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index eb260a8f84..6d8086435d 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -3,7 +3,12 @@ import base64
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 
@@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
             return file.remote_url
         case FileAttribute.EXTENSION:
             return file.extension
-        case _:
-            raise ValueError(f"Invalid file attribute: {attr}")
 
 
 def to_prompt_message_content(
     f: File,
     /,
     *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
- """ match f.type: case FileType.IMAGE: + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": data = _to_url(f) else: @@ -65,7 +52,7 @@ def to_prompt_message_content( return ImagePromptMessageContent(data=data, detail=image_detail_config) case FileType.AUDIO: - encoded_string = _file_to_encoded_string(f) + encoded_string = _get_encoded_string(f) if f.extension is None: raise ValueError("Missing file extension") return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) @@ -74,9 +61,20 @@ def to_prompt_message_content( data = _to_url(f) else: data = _to_base64_data_string(f) + if f.extension is None: + raise ValueError("Missing file extension") return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) + case FileType.DOCUMENT: + data = _get_encoded_string(f) + if f.mime_type is None: + raise ValueError("Missing file mime_type") + return DocumentPromptMessageContent( + encode_format="base64", + mime_type=f.mime_type, + data=data, + ) case _: - raise ValueError("file type f.type is not supported") + raise ValueError(f"file type {f.type} is not supported") def download(f: File, /): @@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /): case FileTransferMethod.REMOTE_URL: response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() - content = response.content - encoded_string = base64.b64encode(content).decode("utf-8") - return encoded_string + data = response.content case FileTransferMethod.LOCAL_FILE: upload_file = file_repository.get_upload_file(session=db.session(), file=f) data = _download_file_content(upload_file.key) - encoded_string = base64.b64encode(data).decode("utf-8") - return encoded_string case FileTransferMethod.TOOL_FILE: tool_file = file_repository.get_tool_file(session=db.session(), file=f) data = _download_file_content(tool_file.file_key) - encoded_string = base64.b64encode(data).decode("utf-8") - return encoded_string - case _: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string def _to_base64_data_string(f: File, /): @@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /): return f"data:{f.mime_type};base64,{encoded_string}" -def _file_to_encoded_string(f: File, /): - match f.type: - case FileType.IMAGE: - return _to_base64_data_string(f) - case FileType.VIDEO: - return _to_base64_data_string(f) - case FileType.AUDIO: - return _get_encoded_string(f) - case _: - raise ValueError(f"file type {f.type} is not supported") - - def _to_url(f: File, /): if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 03c4b8d49d..011ff382ea 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from enum import Enum +from enum import StrEnum from threading import Lock from typing import Any, Optional @@ -31,7 +31,7 @@ class CodeExecutionResponse(BaseModel): data: Data -class CodeLanguage(str, Enum): +class CodeLanguage(StrEnum): PYTHON3 = "python3" JINJA2 = "jinja2" JAVASCRIPT = "javascript" diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 422dff172f..388cbc0be3 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py 
@@ -86,7 +86,7 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: logging.warning("Document deleted, document id: {}".format(dataset_document.id)) @@ -94,7 +94,7 @@ class IndexingRunner: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def run_in_splitting_status(self, dataset_document: DatasetDocument): @@ -142,13 +142,13 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def run_in_indexing_status(self, dataset_document: DatasetDocument): @@ -200,13 +200,13 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def indexing_estimate( @@ -372,7 +372,7 @@ class IndexingRunner: after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), }, ) @@ -464,7 +464,7 @@ class IndexingRunner: doc_store.add_documents(documents) # update document status to indexing - cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", @@ -479,7 +479,7 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), }, ) 
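The `datetime.timezone.utc` → `datetime.UTC` rewrites in these `indexing_runner.py` hunks (and throughout this diff) rely on the `UTC` alias introduced in Python 3.11, which lines up with this PR dropping 3.10 support; the trailing `.replace(tzinfo=None)` keeps the stored timestamps naive, matching the existing column convention. A short sketch of the idiom:

```python
from datetime import UTC, datetime, timezone

# datetime.UTC (new in Python 3.11) is an alias for datetime.timezone.utc,
# so the rewritten call sites produce identical values.
assert UTC is timezone.utc

# Timestamp columns in this codebase store naive datetimes that are UTC by
# convention, hence the .replace(tzinfo=None) after taking an aware "now".
naive_utc_now = datetime.now(UTC).replace(tzinfo=None)
assert naive_utc_now.tzinfo is None
```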
@@ -680,7 +680,7 @@ class IndexingRunner: after_indexing_status="completed", extra_update_params={ DatasetDocument.tokens: tokens, - DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.error: None, }, @@ -705,7 +705,7 @@ class IndexingRunner: { DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } ) @@ -738,7 +738,7 @@ class IndexingRunner: { DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } ) @@ -849,7 +849,7 @@ class IndexingRunner: doc_store.add_documents(documents) # update document status to indexing - cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", @@ -864,7 +864,7 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), }, ) pass diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 688fb4776a..81d08dc885 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,8 +1,8 @@ +from collections.abc import Sequence from typing import Optional from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.file import file_manager -from core.file.models import FileType from core.model_manager import ModelInstance from core.model_runtime.entities import ( AssistantPromptMessage, @@ -27,7 +27,7 @@ class TokenBufferMemory: def get_history_prompt_messages( self, max_token_limit: int = 2000, message_limit: Optional[int] = None - ) -> list[PromptMessage]: + ) -> Sequence[PromptMessage]: """ Get history prompt messages. 
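The signature changes starting here (`token_buffer_memory.py`, and `model_manager.py` / `base_callback.py` below) swap `list[PromptMessage]` for `collections.abc.Sequence[PromptMessage]`. Because `Sequence` promises read-only access and is covariant while `list` is invariant, callers may now pass tuples or lists of `PromptMessage` subclasses without type errors; runtime behavior is unchanged. A sketch under those typing rules, using minimal placeholder classes:

```python
from collections.abc import Sequence


class PromptMessage: ...


class UserPromptMessage(PromptMessage): ...


def count_messages(messages: Sequence[PromptMessage]) -> int:
    # Sequence is covariant: a list[UserPromptMessage] is accepted here,
    # where a list[PromptMessage] parameter would reject it.
    return len(messages)


user_messages: list[UserPromptMessage] = [UserPromptMessage()]
count_messages(user_messages)            # OK under a type checker
count_messages((UserPromptMessage(),))   # tuples qualify as Sequence too
```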
:param max_token_limit: max token limit @@ -102,12 +102,11 @@ class TokenBufferMemory: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=message.query)) for file in file_objs: - if file.type in {FileType.IMAGE, FileType.AUDIO}: - prompt_message = file_manager.to_prompt_message_content( - file, - image_detail_config=detail, - ) - prompt_message_contents.append(prompt_message) + prompt_message = file_manager.to_prompt_message_content( + file, + image_detail_config=detail, + ) + prompt_message_contents.append(prompt_message) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index c12134a97d..fc7ae18783 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -136,10 +136,10 @@ class ModelInstance: def invoke_llm( self, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: Optional[dict] = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 6bd9325785..8870b34435 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Sequence from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -31,7 +32,7 @@ class Callback(ABC): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -60,7 +61,7 @@ class Callback(ABC): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ): @@ -90,7 +91,7 @@ class Callback(ABC): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -120,7 +121,7 @@ class Callback(ABC): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index f5d4427e3e..5e52f10b4c 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa from .message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, + DocumentPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContent, @@ -37,4 +38,5 @@ __all__ = [ "LLMResultChunk", "LLMResultChunkDelta", "AudioPromptMessageContent", + "DocumentPromptMessageContent", ] diff --git 
a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 3c244d368e..f2870209bb 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,6 +1,7 @@ from abc import ABC -from enum import Enum -from typing import Optional +from collections.abc import Sequence +from enum import Enum, StrEnum +from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -48,7 +49,7 @@ class PromptMessageFunction(BaseModel): function: PromptMessageTool -class PromptMessageContentType(Enum): +class PromptMessageContentType(StrEnum): """ Enum class for prompt message content type. """ @@ -57,6 +58,7 @@ class PromptMessageContentType(Enum): IMAGE = "image" AUDIO = "audio" VIDEO = "video" + DOCUMENT = "document" class PromptMessageContent(BaseModel): @@ -93,7 +95,7 @@ class ImagePromptMessageContent(PromptMessageContent): Model class for image prompt message content. """ - class DETAIL(str, Enum): + class DETAIL(StrEnum): LOW = "low" HIGH = "high" @@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent): detail: DETAIL = DETAIL.LOW +class DocumentPromptMessageContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.DOCUMENT + encode_format: Literal["base64"] + mime_type: str + data: str + + class PromptMessage(ABC, BaseModel): """ Model class for prompt message. """ role: PromptMessageRole - content: Optional[str | list[PromptMessageContent]] = None + content: Optional[str | Sequence[PromptMessageContent]] = None name: Optional[str] = None def is_empty(self) -> bool: diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 09b9e3aa0d..3225f03fbd 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,5 +1,5 @@ from decimal import Decimal -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Optional from pydantic import BaseModel, ConfigDict @@ -82,9 +82,12 @@ class ModelFeature(Enum): AGENT_THOUGHT = "agent-thought" VISION = "vision" STREAM_TOOL_CALL = "stream-tool-call" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" -class DefaultParameterName(str, Enum): +class DefaultParameterName(StrEnum): """ Enum class for parameter template variable. 
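The `class X(str, Enum)` → `class X(StrEnum)` migration running through these files is not purely cosmetic. Python 3.11 changed `__format__` for mixed-in enums, so f-strings of a `(str, Enum)` class silently started rendering the qualified member name, while `StrEnum` always renders the plain value — relevant now that this PR targets 3.11/3.12. A sketch of the difference:

```python
from enum import Enum, StrEnum  # StrEnum requires Python 3.11+


class OldStyle(str, Enum):
    IMAGE = "image"


class NewStyle(StrEnum):
    IMAGE = "image"


# On 3.11+, format()/f-strings of a (str, Enum) mixin show the qualified
# member name, a silent change from 3.10...
print(f"{OldStyle.IMAGE}")  # OldStyle.IMAGE
# ...while StrEnum keeps behaving like the raw string value.
print(f"{NewStyle.IMAGE}")  # image
assert OldStyle.IMAGE == "image" and NewStyle.IMAGE == "image"
```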
""" diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 18a5e8be34..5447058ad7 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -1,6 +1,6 @@ import logging import time -from collections.abc import Generator +from collections.abc import Generator, Sequence from typing import Optional, Union from pydantic import ConfigDict @@ -41,7 +41,7 @@ class LargeLanguageModel(AIModel): prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -96,7 +96,7 @@ class LargeLanguageModel(AIModel): model_parameters=model_parameters, prompt_messages=prompt_messages, tools=tools, - stop=stop, + stop=list(stop) if stop else None, stream=stream, ) @@ -176,7 +176,7 @@ class LargeLanguageModel(AIModel): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -318,7 +318,7 @@ class LargeLanguageModel(AIModel): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -364,7 +364,7 @@ class LargeLanguageModel(AIModel): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -411,7 +411,7 @@ class LargeLanguageModel(AIModel): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -459,7 +459,7 @@ class LargeLanguageModel(AIModel): prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png deleted file mode 100644 index dfe8e78049..0000000000 Binary files a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png and /dev/null differ diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg deleted file mode 100644 index bb23bffcf1..0000000000 --- a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png 
b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png deleted file mode 100644 index b154821db9..0000000000 Binary files a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png and /dev/null differ diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg deleted file mode 100644 index c5c608cd7c..0000000000 --- a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg +++ /dev/null @@ -1,11 +0,0 @@ - - - - - - - - - - - diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.py b/api/core/model_runtime/model_providers/gpustack/gpustack.py deleted file mode 100644 index 321100167e..0000000000 --- a/api/core/model_runtime/model_providers/gpustack/gpustack.py +++ /dev/null @@ -1,10 +0,0 @@ -import logging - -from core.model_runtime.model_providers.__base.model_provider import ModelProvider - -logger = logging.getLogger(__name__) - - -class GPUStackProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - pass diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.yaml b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml deleted file mode 100644 index ee4a3c159a..0000000000 --- a/api/core/model_runtime/model_providers/gpustack/gpustack.yaml +++ /dev/null @@ -1,120 +0,0 @@ -provider: gpustack -label: - en_US: GPUStack -icon_small: - en_US: icon_s_en.png -icon_large: - en_US: icon_l_en.png -supported_model_types: - - llm - - text-embedding - - rerank -configurate_methods: - - customizable-model -model_credential_schema: - model: - label: - en_US: Model Name - zh_Hans: 模型名称 - placeholder: - en_US: Enter your model name - zh_Hans: 输入模型名称 - credential_form_schemas: - - variable: endpoint_url - label: - zh_Hans: 服务器地址 - en_US: Server URL - type: text-input - required: true - placeholder: - zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100 - en_US: Enter the GPUStack server URL, e.g. 
http://192.168.1.100 - - variable: api_key - label: - en_US: API Key - type: secret-input - required: true - placeholder: - zh_Hans: 输入您的 API Key - en_US: Enter your API Key - - variable: mode - show_on: - - variable: __model_type - value: llm - label: - en_US: Completion mode - type: select - required: false - default: chat - placeholder: - zh_Hans: 选择补全类型 - en_US: Select completion type - options: - - value: completion - label: - en_US: Completion - zh_Hans: 补全 - - value: chat - label: - en_US: Chat - zh_Hans: 对话 - - variable: context_size - label: - zh_Hans: 模型上下文长度 - en_US: Model context size - required: true - type: text-input - default: "8192" - placeholder: - zh_Hans: 输入您的模型上下文长度 - en_US: Enter your Model context size - - variable: max_tokens_to_sample - label: - zh_Hans: 最大 token 上限 - en_US: Upper bound for max tokens - show_on: - - variable: __model_type - value: llm - default: "8192" - type: text-input - - variable: function_calling_type - show_on: - - variable: __model_type - value: llm - label: - en_US: Function calling - type: select - required: false - default: no_call - options: - - value: function_call - label: - en_US: Function Call - zh_Hans: Function Call - - value: tool_call - label: - en_US: Tool Call - zh_Hans: Tool Call - - value: no_call - label: - en_US: Not Support - zh_Hans: 不支持 - - variable: vision_support - show_on: - - variable: __model_type - value: llm - label: - zh_Hans: Vision 支持 - en_US: Vision Support - type: select - required: false - default: no_support - options: - - value: support - label: - en_US: Support - zh_Hans: 支持 - - value: no_support - label: - en_US: Not Support - zh_Hans: 不支持 diff --git a/api/core/model_runtime/model_providers/gpustack/llm/__init__.py b/api/core/model_runtime/model_providers/gpustack/llm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/model_runtime/model_providers/gpustack/llm/llm.py b/api/core/model_runtime/model_providers/gpustack/llm/llm.py deleted file mode 100644 index ce6780b6a7..0000000000 --- a/api/core/model_runtime/model_providers/gpustack/llm/llm.py +++ /dev/null @@ -1,45 +0,0 @@ -from collections.abc import Generator - -from yarl import URL - -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import ( - PromptMessage, - PromptMessageTool, -) -from core.model_runtime.model_providers.openai_api_compatible.llm.llm import ( - OAIAPICompatLargeLanguageModel, -) - - -class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - user: str | None = None, - ) -> LLMResult | Generator: - return super()._invoke( - model, - credentials, - prompt_messages, - model_parameters, - tools, - stop, - stream, - user, - ) - - def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) - - @staticmethod - def _add_custom_parameters(credentials: dict) -> None: - credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") - credentials["mode"] = "chat" diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py b/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git 
a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py deleted file mode 100644 index 5ea7532564..0000000000 --- a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py +++ /dev/null @@ -1,146 +0,0 @@ -from json import dumps -from typing import Optional - -import httpx -from requests import post -from yarl import URL - -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ( - AIModelEntity, - FetchFrom, - ModelPropertyKey, - ModelType, -) -from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.rerank_model import RerankModel - - -class GPUStackRerankModel(RerankModel): - """ - Model class for GPUStack rerank model. - """ - - def _invoke( - self, - model: str, - credentials: dict, - query: str, - docs: list[str], - score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None, - ) -> RerankResult: - """ - Invoke rerank model - - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n documents to return - :param user: unique user id - :return: rerank result - """ - if len(docs) == 0: - return RerankResult(model=model, docs=[]) - - endpoint_url = credentials["endpoint_url"] - headers = { - "Authorization": f"Bearer {credentials.get('api_key')}", - "Content-Type": "application/json", - } - - data = {"model": model, "query": query, "documents": docs, "top_n": top_n} - - try: - response = post( - str(URL(endpoint_url) / "v1" / "rerank"), - headers=headers, - data=dumps(data), - timeout=10, - ) - response.raise_for_status() - results = response.json() - - rerank_documents = [] - for result in results["results"]: - index = result["index"] - if "document" in result: - text = result["document"]["text"] - else: - text = docs[index] - - rerank_document = RerankDocument( - index=index, - text=text, - score=result["relevance_score"], - ) - - if score_threshold is None or result["relevance_score"] >= score_threshold: - rerank_documents.append(rerank_document) - - return RerankResult(model=model, docs=rerank_documents) - except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) - - def validate_credentials(self, model: str, credentials: dict) -> None: - """ - Validate model credentials - - :param model: model name - :param credentials: model credentials - :return: - """ - try: - self._invoke( - model=model, - credentials=credentials, - query="What is the capital of the United States?", - docs=[ - "Carson City is the capital city of the American state of Nevada. At the 2010 United States " - "Census, Carson City had a population of 55,274.", - "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " - "are a political division controlled by the United States. 
Its capital is Saipan.", - ], - score_threshold=0.8, - ) - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - - @property - def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - """ - Map model invoke error to unified error - """ - return { - InvokeConnectionError: [httpx.ConnectError], - InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError], - } - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - """ - generate custom model entities from credentials - """ - entity = AIModelEntity( - model=model, - label=I18nObject(en_US=model), - model_type=ModelType.RERANK, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, - ) - - return entity diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py deleted file mode 100644 index eb324491a2..0000000000 --- a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Optional - -from yarl import URL - -from core.entities.embedding_type import EmbeddingInputType -from core.model_runtime.entities.text_embedding_entities import ( - TextEmbeddingResult, -) -from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( - OAICompatEmbeddingModel, -) - - -class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel): - """ - Model class for GPUStack text embedding model. - """ - - def _invoke( - self, - model: str, - credentials: dict, - texts: list[str], - user: Optional[str] = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> TextEmbeddingResult: - return super()._invoke(model, credentials, texts, user, input_type) - - def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) - - @staticmethod - def _add_custom_parameters(credentials: dict) -> None: - credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3.5-sonnet-v2.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3.5-sonnet-v2.yaml deleted file mode 100644 index 37b9f30cc3..0000000000 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3.5-sonnet-v2.yaml +++ /dev/null @@ -1,55 +0,0 @@ -model: claude-3-5-sonnet-v2@20241022 -label: - en_US: Claude 3.5 Sonnet v2 -model_type: llm -features: - - agent-thought - - vision -model_properties: - mode: chat - context_size: 200000 -parameter_rules: - - name: max_tokens - use_template: max_tokens - required: true - type: int - default: 8192 - min: 1 - max: 8192 - help: - zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 - en_US: The maximum number of tokens to generate before stopping. 
Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. - - name: temperature - use_template: temperature - required: false - type: float - default: 1 - min: 0.0 - max: 1.0 - help: - zh_Hans: 生成内容的随机性。 - en_US: The amount of randomness injected into the response. - - name: top_p - required: false - type: float - default: 0.999 - min: 0.000 - max: 1.000 - help: - zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 - en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. - - name: top_k - required: false - type: int - default: 0 - min: 0 - # tip docs from aws has error, max value is 500 - max: 500 - help: - zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 - en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. -pricing: - input: '0.003' - output: '0.015' - unit: '0.001' - currency: USD diff --git a/api/core/model_runtime/model_providers/vessl_ai/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png deleted file mode 100644 index 18ba350fa0..0000000000 Binary files a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png and /dev/null differ diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg deleted file mode 100644 index 242f4e82b2..0000000000 --- a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py deleted file mode 100644 index 034c066ab5..0000000000 --- a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py +++ /dev/null @@ -1,83 +0,0 @@ -from decimal import Decimal - -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - FetchFrom, - ModelPropertyKey, - ModelType, - ParameterRule, - ParameterType, - PriceConfig, -) -from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel - - -class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - features = [] - - entity = AIModelEntity( - model=model, - label=I18nObject(en_US=model), - model_type=ModelType.LLM, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - features=features, - model_properties={ - ModelPropertyKey.MODE: credentials.get("mode"), - }, - parameter_rules=[ - ParameterRule( - 
name=DefaultParameterName.TEMPERATURE.value, - label=I18nObject(en_US="Temperature"), - type=ParameterType.FLOAT, - default=float(credentials.get("temperature", 0.7)), - min=0, - max=2, - precision=2, - ), - ParameterRule( - name=DefaultParameterName.TOP_P.value, - label=I18nObject(en_US="Top P"), - type=ParameterType.FLOAT, - default=float(credentials.get("top_p", 1)), - min=0, - max=1, - precision=2, - ), - ParameterRule( - name=DefaultParameterName.TOP_K.value, - label=I18nObject(en_US="Top K"), - type=ParameterType.INT, - default=int(credentials.get("top_k", 50)), - min=-2147483647, - max=2147483647, - precision=0, - ), - ParameterRule( - name=DefaultParameterName.MAX_TOKENS.value, - label=I18nObject(en_US="Max Tokens"), - type=ParameterType.INT, - default=512, - min=1, - max=int(credentials.get("max_tokens_to_sample", 4096)), - ), - ], - pricing=PriceConfig( - input=Decimal(credentials.get("input_price", 0)), - output=Decimal(credentials.get("output_price", 0)), - unit=Decimal(credentials.get("unit", 0)), - currency=credentials.get("currency", "USD"), - ), - ) - - if credentials["mode"] == "chat": - entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif credentials["mode"] == "completion": - entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value - else: - raise ValueError(f"Unknown completion type {credentials['completion_type']}") - - return entity diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py deleted file mode 100644 index 7a987c6710..0000000000 --- a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py +++ /dev/null @@ -1,10 +0,0 @@ -import logging - -from core.model_runtime.model_providers.__base.model_provider import ModelProvider - -logger = logging.getLogger(__name__) - - -class VesslAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - pass diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml deleted file mode 100644 index 2138b281b9..0000000000 --- a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml +++ /dev/null @@ -1,56 +0,0 @@ -provider: vessl_ai -label: - en_US: VESSL AI -icon_small: - en_US: icon_s_en.svg -icon_large: - en_US: icon_l_en.png -background: "#F1EFED" -help: - title: - en_US: How to deploy VESSL AI LLM Model Endpoint - url: - en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment -supported_model_types: - - llm -configurate_methods: - - customizable-model -model_credential_schema: - model: - label: - en_US: Model Name - placeholder: - en_US: Enter model name - credential_form_schemas: - - variable: endpoint_url - label: - en_US: Endpoint Url - type: text-input - required: true - placeholder: - en_US: Enter VESSL AI service endpoint url - - variable: api_key - required: true - label: - en_US: API Key - type: secret-input - placeholder: - en_US: Enter VESSL AI secret key - - variable: mode - show_on: - - variable: __model_type - value: llm - label: - en_US: Completion Mode - type: select - required: false - default: chat - placeholder: - en_US: Select completion mode - options: - - value: completion - label: - en_US: Completion - - value: chat - label: - en_US: Chat diff --git a/api/core/model_runtime/model_providers/x/__init__.py b/api/core/model_runtime/model_providers/x/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git 
a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg deleted file mode 100644 index f8b745cb13..0000000000 --- a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/llm/__init__.py b/api/core/model_runtime/model_providers/x/llm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml deleted file mode 100644 index 7c305735b9..0000000000 --- a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml +++ /dev/null @@ -1,63 +0,0 @@ -model: grok-beta -label: - en_US: Grok beta -model_type: llm -features: - - multi-tool-call -model_properties: - mode: chat - context_size: 131072 -parameter_rules: - - name: temperature - label: - en_US: "Temperature" - zh_Hans: "采样温度" - type: float - default: 0.7 - min: 0.0 - max: 2.0 - precision: 1 - required: true - help: - en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." - zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" - - - name: top_p - label: - en_US: "Top P" - zh_Hans: "Top P" - type: float - default: 0.7 - min: 0.0 - max: 1.0 - precision: 1 - required: true - help: - en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." - zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" - - - name: frequency_penalty - use_template: frequency_penalty - label: - en_US: "Frequency Penalty" - zh_Hans: "频率惩罚" - type: float - default: 0 - min: 0 - max: 2.0 - precision: 1 - required: false - help: - en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." - zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。" - - - name: user - use_template: text - label: - en_US: "User" - zh_Hans: "用户" - type: string - required: false - help: - en_US: "Used to track and differentiate conversation requests from different users." 
- zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py deleted file mode 100644 index 3f5325a857..0000000000 --- a/api/core/model_runtime/model_providers/x/llm/llm.py +++ /dev/null @@ -1,37 +0,0 @@ -from collections.abc import Generator -from typing import Optional, Union - -from yarl import URL - -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult -from core.model_runtime.entities.message_entities import ( - PromptMessage, - PromptMessageTool, -) -from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel - - -class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - self._add_custom_parameters(credentials) - return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) - - def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) - - @staticmethod - def _add_custom_parameters(credentials) -> None: - credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1" - credentials["mode"] = LLMMode.CHAT.value - credentials["function_calling_type"] = "tool_call" diff --git a/api/core/model_runtime/model_providers/x/x.py b/api/core/model_runtime/model_providers/x/x.py deleted file mode 100644 index e3f2b8eeba..0000000000 --- a/api/core/model_runtime/model_providers/x/x.py +++ /dev/null @@ -1,25 +0,0 @@ -import logging - -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.model_provider import ModelProvider - -logger = logging.getLogger(__name__) - - -class XAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - """ - Validate provider credentials - if validate failed, raise exception - - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. - """ - try: - model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials(model="grok-beta", credentials=credentials) - except CredentialsValidateFailedError as ex: - raise ex - except Exception as ex: - logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") - raise ex diff --git a/api/core/model_runtime/model_providers/x/x.yaml b/api/core/model_runtime/model_providers/x/x.yaml deleted file mode 100644 index 90d1cbfe7e..0000000000 --- a/api/core/model_runtime/model_providers/x/x.yaml +++ /dev/null @@ -1,38 +0,0 @@ -provider: x -label: - en_US: xAI -description: - en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe. 
-icon_small: - en_US: x-ai-logo.svg -icon_large: - en_US: x-ai-logo.svg -help: - title: - en_US: Get your token from xAI - zh_Hans: 从 xAI 获取 token - url: - en_US: https://x.ai/api -supported_model_types: - - llm -configurate_methods: - - predefined-model -provider_credential_schema: - credential_form_schemas: - - variable: api_key - label: - en_US: API Key - type: secret-input - required: true - placeholder: - zh_Hans: 在此输入您的 API Key - en_US: Enter your API Key - - variable: endpoint_url - label: - en_US: API Base - type: text-input - required: false - default: https://api.x.ai/v1 - placeholder: - zh_Hans: 在此输入您的 API Base - en_US: Enter your API Base diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 256595286f..71ff03b6ef 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator @@ -122,7 +122,7 @@ trace_info_info_map = { } -class TraceTaskName(str, Enum): +class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" WORKFLOW_TRACE = "workflow" MESSAGE_TRACE = "message" diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index 447b799f1f..f486da3a6d 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -39,7 +39,7 @@ def validate_input_output(v, field_name): return v -class LevelEnum(str, Enum): +class LevelEnum(StrEnum): DEBUG = "DEBUG" WARNING = "WARNING" ERROR = "ERROR" @@ -178,7 +178,7 @@ class LangfuseSpan(BaseModel): return validate_input_output(v, field_name) -class UnitEnum(str, Enum): +class UnitEnum(StrEnum): CHARACTERS = "CHARACTERS" TOKENS = "TOKENS" SECONDS = "SECONDS" diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 16c76f363c..99221d669b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -8,7 +8,7 @@ from pydantic_core.core_schema import ValidationInfo from core.ops.utils import replace_text_with_content -class LangSmithRunType(str, Enum): +class LangSmithRunType(StrEnum): tool = "tool" chain = "chain" llm = "llm" diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 5a3481b963..93dd92f188 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from core.file.models import File -class ModelMode(str, enum.Enum): +class ModelMode(enum.StrEnum): COMPLETION = "completion" CHAT = "chat" diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 5eec5e3c99..aa175153bc 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ 
b/api/core/prompt/utils/prompt_message_util.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import cast from core.model_runtime.entities import ( @@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode class PromptMessageUtil: @staticmethod - def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]: """ Prompt messages to prompt for saving. :param model_mode: model mode diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index 3c6ab2e4cf..754b0d18b7 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -12,7 +12,7 @@ class CleanProcessor: # Unicode U+FFFE text = re.sub("\ufffe", "", text) - rules = process_rule["rules"] if process_rule else None + rules = process_rule["rules"] if process_rule else {} if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: diff --git a/api/core/rag/datasource/keyword/keyword_type.py b/api/core/rag/datasource/keyword/keyword_type.py index d6deba3fb0..d845c7111d 100644 --- a/api/core/rag/datasource/keyword/keyword_type.py +++ b/api/core/rag/datasource/keyword/keyword_type.py @@ -1,5 +1,5 @@ -from enum import Enum +from enum import StrEnum -class KeyWordType(str, Enum): +class KeyWordType(StrEnum): JIEBA = "jieba" diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 8e53e3ae84..05183c0371 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class VectorType(str, Enum): +class VectorType(StrEnum): ANALYTICDB = "analyticdb" CHROMA = "chroma" MILVUS = "milvus" diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 313bdce48b..b23da1113e 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -114,10 +114,10 @@ class WordExtractor(BaseExtractor): mime_type=mime_type or "", created_by=self.user_id, created_by_role=CreatedByRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), ) db.session.add(upload_file) diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index fc82b2080b..6ae432a526 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -27,11 +27,11 @@ class RerankModelRunner(BaseRerankRunner): :return: """ docs = [] - doc_id = set() + doc_ids = set() unique_documents = [] for document in documents: - if document.provider == "dify" and document.metadata["doc_id"] not in doc_id: - doc_id.add(document.metadata["doc_id"]) + if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids: + doc_ids.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) elif document.provider == "external": diff --git a/api/core/rag/rerank/rerank_type.py b/api/core/rag/rerank/rerank_type.py index d71eb2daa8..b2d1314654 100644 
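The `doc_id` → `doc_ids` rename in `rerank_model.py` above, together with the companion fix in `weight_rerank.py` just below (where the tracking set was immediately shadowed by a loop variable of the same name, so de-duplication never actually happened), restores de-duplication by document id. The corrected pattern, sketched with plain dicts standing in for Dify's `Document` objects:

```python
# De-duplicate by metadata doc_id, keeping the first occurrence.
documents = [
    {"doc_id": "a", "text": "first"},
    {"doc_id": "a", "text": "duplicate"},
    {"doc_id": "b", "text": "second"},
]

doc_ids: set[str] = set()  # the set is named apart from the key it tracks
unique_documents = []
for document in documents:
    if document["doc_id"] not in doc_ids:
        doc_ids.add(document["doc_id"])
        unique_documents.append(document)

assert [d["doc_id"] for d in unique_documents] == ["a", "b"]
```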
--- a/api/core/rag/rerank/rerank_type.py +++ b/api/core/rag/rerank/rerank_type.py @@ -1,6 +1,6 @@ -from enum import Enum +from enum import StrEnum -class RerankMode(str, Enum): +class RerankMode(StrEnum): RERANKING_MODEL = "reranking_model" WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index b706f29bb1..4719be012f 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -37,11 +37,10 @@ class WeightRerankRunner(BaseRerankRunner): :return: """ unique_documents = [] - doc_id = set() + doc_ids = set() for document in documents: - doc_id = document.metadata.get("doc_id") - if doc_id not in doc_id: - doc_id.add(doc_id) + if document.metadata["doc_id"] not in doc_ids: + doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) documents = unique_documents diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index 8c7daa7a69..185de9a0d4 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, Optional, Union from pytz import timezone as pytz_timezone @@ -23,7 +23,7 @@ class CurrentTimeTool(BuiltinTool): tz = tool_parameters.get("timezone", "UTC") fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" if tz == "UTC": - return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") + return self.create_text_message(f"{datetime.now(UTC).strftime(fm)}") try: tz = pytz_timezone(tz) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index e59f92612c..6e8137a8e9 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,7 +1,7 @@ import json from collections.abc import Generator, Iterable from copy import deepcopy -from datetime import datetime, timezone +from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Optional, Union, cast @@ -64,7 +64,12 @@ class ToolEngine: if parameters and len(parameters) == 1: tool_parameters = {parameters[0].name: tool_parameters} else: - raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") + try: + tool_parameters = json.loads(tool_parameters) + except Exception: + pass + if not isinstance(tool_parameters, dict): + raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool try: @@ -195,10 +200,7 @@ class ToolEngine: """ Invoke the tool with the given arguments.
""" - if not tool.runtime: - raise ValueError("missing runtime in tool") - - started_at = datetime.now(timezone.utc) + started_at = datetime.now(UTC) meta = ToolInvokeMeta( time_cost=0.0, error=None, @@ -216,7 +218,7 @@ class ToolEngine: meta.error = str(e) raise ToolEngineInvokeError(meta) finally: - ended_at = datetime.now(timezone.utc) + ended_at = datetime.now(UTC) meta.time_cost = (ended_at - started_at).total_seconds() yield meta diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index b71882b043..69bd5567a4 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -118,11 +118,11 @@ class FileSegment(Segment): @property def log(self) -> str: - return str(self.value) + return "" @property def text(self) -> str: - return str(self.value) + return "" class ArrayAnySegment(ArraySegment): @@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment): for item in self.value: items.append(item.markdown) return "\n".join(items) + + @property + def log(self) -> str: + return "" + + @property + def text(self) -> str: + return "" diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 53c2e8a3aa..af6a2a2937 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class SegmentType(str, Enum): +class SegmentType(StrEnum): NONE = "none" NUMBER = "number" STRING = "string" diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index a747266661..e174d3baa0 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from enum import Enum +from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel @@ -8,7 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from models.workflow import WorkflowNodeExecutionStatus -class NodeRunMetadataKey(str, Enum): +class NodeRunMetadataKey(StrEnum): """ Node Run Metadata Key. """ @@ -36,7 +36,7 @@ class NodeRunResult(BaseModel): inputs: Optional[Mapping[str, Any]] = None # node inputs process_data: Optional[dict[str, Any]] = None # process data - outputs: Optional[dict[str, Any]] = None # node outputs + outputs: Optional[Mapping[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata llm_usage: Optional[LLMUsage] = None # llm usage diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 213ed57f57..9642efa1a5 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class SystemVariableKey(str, Enum): +class SystemVariableKey(StrEnum): """ System Variables. 
""" diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index bb24b51112..baeec9bf01 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timezone +from datetime import UTC, datetime from enum import Enum from typing import Optional @@ -63,7 +63,7 @@ class RouteNodeState(BaseModel): raise Exception(f"Invalid route status {run_result.status}") self.node_run_result = run_result - self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + self.finished_at = datetime.now(UTC).replace(tzinfo=None) class RuntimeRouteState(BaseModel): @@ -81,7 +81,7 @@ class RuntimeRouteState(BaseModel): :param node_id: node id """ - state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) self.node_state_mapping[state.id] = state return state diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 208144655b..9e9e52910e 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class NodeType(str, Enum): +class NodeType(StrEnum): START = "start" END = "end" ANSWER = "answer" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 80b322b068..22ad2a39f6 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -108,7 +108,7 @@ class Executor: self.content = self.variable_pool.convert_template(data[0].value).text case "json": json_string = self.variable_pool.convert_template(data[0].value).text - json_object = json.loads(json_string) + json_object = json.loads(json_string, strict=False) self.json = json_object # self.json = self._parse_object_contains_variables(json_object) case "binary": diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index ebcb6f82fb..7a489dd725 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum from typing import Any, Optional from pydantic import Field @@ -6,7 +6,7 @@ from pydantic import Field from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData -class ErrorHandleMode(str, Enum): +class ErrorHandleMode(StrEnum): TERMINATED = "terminated" CONTINUE_ON_ERROR = "continue-on-error" REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index d5428f0286..22f242a42f 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -2,7 +2,7 @@ import logging import uuid from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait -from datetime import datetime, timezone +from datetime import UTC, datetime from queue import Empty, Queue from typing import TYPE_CHECKING, Any, Optional, cast @@ -135,7 +135,7 @@ class IterationNode(BaseNode[IterationNodeData]): thread_pool_id=self.thread_pool_id, ) - start_at = 
datetime.now(timezone.utc).replace(tzinfo=None) + start_at = datetime.now(UTC).replace(tzinfo=None) yield IterationRunStartedEvent( iteration_id=self.id, @@ -367,7 +367,7 @@ class IterationNode(BaseNode[IterationNodeData]): """ run single iteration """ - iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None) + iter_start_at = datetime.now(UTC).replace(tzinfo=None) try: rst = graph_engine.run() @@ -440,7 +440,7 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -461,7 +461,7 @@ class IterationNode(BaseNode[IterationNodeData]): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -503,7 +503,7 @@ class IterationNode(BaseNode[IterationNodeData]): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index a25d563fe0..19a66087f7 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -39,7 +39,14 @@ class VisionConfig(BaseModel): class PromptConfig(BaseModel): - jinja2_variables: Optional[list[VariableSelector]] = None + jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) + + @field_validator("jinja2_variables", mode="before") + @classmethod + def convert_none_jinja2_variables(cls, v: Any): + if v is None: + return [] + return v class LLMNodeChatModelMessage(ChatModelMessage): @@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: Optional[PromptConfig] = None + prompt_config: PromptConfig = Field(default_factory=PromptConfig) memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) + + @field_validator("prompt_config", mode="before") + @classmethod + def convert_none_prompt_config(cls, v: Any): + if v is None: + return PromptConfig() + return v diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index f858be2515..6599221691 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -24,3 +24,17 @@ class LLMModeRequiredError(LLMNodeError): class NoPromptFoundError(LLMNodeError): """Raised when no prompt is found in the LLM configuration.""" + + +class TemplateTypeNotSupportError(LLMNodeError): + 
def __init__(self, *, type_name: str): + super().__init__(f"Prompt type {type_name} is not supported.") + + +class MemoryRolePrefixRequiredError(LLMNodeError): + """Raised when memory role prefix is required for completion model.""" + + +class FileTypeNotSupportError(LLMNodeError): + def __init__(self, *, type_name: str): + super().__init__(f"{type_name} type is not supported by this model") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index eb4d1c9d87..4a6f0ecae9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast @@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file import FileType, file_manager +from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities import ( - AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, TextPromptMessageContent, - VideoPromptMessageContent, ) from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.variables import ( @@ -34,6 +40,8 @@ from core.variables import ( ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode @@ -58,18 +66,23 @@ from .entities import ( ModelConfig, ) from .exc import ( + FileTypeNotSupportError, InvalidContextStructureError, InvalidVariableTypeError, LLMModeRequiredError, LLMNodeError, + MemoryRolePrefixRequiredError, ModelNotExistError, NoPromptFoundError, + TemplateTypeNotSupportError, VariableNotFoundError, ) if TYPE_CHECKING: from core.file.models import File +logger = logging.getLogger(__name__) + class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData @@ -121,19 +134,19 @@ class LLMNode(BaseNode[LLMNodeData]): # fetch memory memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) - # fetch prompt messages + query = None if self.node_data.memory: - query = 
self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) - if not query: - raise VariableNotFoundError("Query not found") - query = query.text - else: - query = None + query = self.node_data.memory.query_prompt_template + if query is None and ( + query_variable := self.graph_runtime_state.variable_pool.get( + (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY) + ) + ): + query = query_variable.text prompt_messages, stop = self._fetch_prompt_messages( - system_query=query, - inputs=inputs, - files=files, + user_query=query, + user_files=files, context=context, memory=memory, model_config=model_config, @@ -141,6 +154,8 @@ class LLMNode(BaseNode[LLMNodeData]): memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, vision_detail=self.node_data.vision.configs.detail, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=self.node_data.prompt_config.jinja2_variables, ) process_data = { @@ -181,6 +196,17 @@ class LLMNode(BaseNode[LLMNodeData]): ) ) return + except Exception as e: + logger.exception(f"Node {self.node_id} failed to run") + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data, + ) + ) + return outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -203,8 +229,8 @@ class LLMNode(BaseNode[LLMNodeData]): self, node_data_model: ModelConfig, model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None, + prompt_messages: Sequence[PromptMessage], + stop: Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: db.session.close() @@ -519,9 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_prompt_messages( self, *, - system_query: str | None = None, - inputs: dict[str, str] | None = None, - files: Sequence["File"], + user_query: str | None = None, + user_files: Sequence["File"], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -529,58 +554,144 @@ class LLMNode(BaseNode[LLMNodeData]): memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: - inputs = inputs or {} + variable_pool: VariablePool, + jinja2_variables: Sequence[VariableSelector], + ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + prompt_messages = [] - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs=inputs, - query=system_query or "", - files=files, - context=context, - memory_config=memory_config, - memory=memory, - model_config=model_config, - ) - stop = model_config.stop + if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend( + _handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) + + # Get memory messages for chat mode + memory_messages = _handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if user_query: + message = LLMNodeChatModelMessage( + text=user_query, + 
role=PromptMessageRole.USER,
+                    edition_type="basic",
+                )
+                prompt_messages.extend(
+                    _handle_list_messages(
+                        messages=[message],
+                        context="",
+                        jinja2_variables=[],
+                        variable_pool=variable_pool,
+                        vision_detail_config=vision_detail,
+                    )
+                )
+
+        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+            # For completion model
+            prompt_messages.extend(
+                _handle_completion_template(
+                    template=prompt_template,
+                    context=context,
+                    jinja2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                )
+            )
+
+            # Get memory text for completion model
+            memory_text = _handle_memory_completion_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Insert histories into the prompt (its content is a single-item list of text content)
+            prompt_content = prompt_messages[0].content[0].data
+            if "#histories#" in prompt_content:
+                prompt_content = prompt_content.replace("#histories#", memory_text)
+            else:
+                prompt_content = memory_text + "\n" + prompt_content
+            prompt_messages[0].content[0].data = prompt_content
+
+            # Add current query to the prompt message
+            if user_query:
+                prompt_content = prompt_messages[0].content[0].data.replace("#sys.query#", user_query)
+                prompt_messages[0].content[0].data = prompt_content
+        else:
+            raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
+
+        if vision_enabled and user_files:
+            file_prompts = []
+            for file in user_files:
+                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
+                file_prompts.append(file_prompt)
+            if (
+                len(prompt_messages) > 0
+                and isinstance(prompt_messages[-1], UserPromptMessage)
+                and isinstance(prompt_messages[-1].content, list)
+            ):
+                prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
+            else:
+                prompt_messages.append(UserPromptMessage(content=file_prompts))
+
+        # Filter prompt messages
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
-            if prompt_message.is_empty():
-                continue
-
-            if not isinstance(prompt_message.content, str):
+            if isinstance(prompt_message.content, list):
                 prompt_message_content = []
-                for content_item in prompt_message.content or []:
-                    # Skip image if vision is disabled
-                    if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
+                for content_item in prompt_message.content:
+                    # Skip content if features are not defined
+                    if not model_config.model_schema.features:
+                        if content_item.type != PromptMessageContentType.TEXT:
+                            continue
+                        prompt_message_content.append(content_item)
                         continue
-                    if isinstance(content_item, ImagePromptMessageContent):
-                        # Override vision config if LLM node has vision config,
-                        # cuz vision detail is related to the configuration from FileUpload feature.
- content_item.detail = vision_detail - prompt_message_content.append(content_item) - elif isinstance( - content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and ModelFeature.VISION not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_config.model_schema.features + ) ): - prompt_message_content.append(content_item) - - if len(prompt_message_content) > 1: - prompt_message.content = prompt_message_content - elif ( - len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT - ): + raise FileTypeNotSupportError(type_name=content_item.type) + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: prompt_message.content = prompt_message_content[0].data - + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue filtered_prompt_messages.append(prompt_message) - if not filtered_prompt_messages: + if len(filtered_prompt_messages) == 0: raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) + stop = model_config.stop return filtered_prompt_messages, stop @classmethod @@ -715,3 +826,198 @@ class LLMNode(BaseNode[LLMNodeData]): } }, } + + +def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + template: str, + jinjia2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +): + if not template: + return "" + + jinjia2_inputs = {} + for jinja2_variable in jinjia2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + code_execute_resp = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=jinjia2_inputs, + ) + result_text = code_execute_resp["result"] + return result_text + + +def _handle_list_messages( + *, + messages: Sequence[LLMNodeChatModelMessage], + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, +) -> Sequence[PromptMessage]: + prompt_messages = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) + 
prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + if isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = UserPromptMessage(content=file_contents) + prompt_messages.append(prompt_message) + + return prompt_messages + + +def _calculate_rest_token( + *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity +) -> int: + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> Sequence[PromptMessage]: + memory_messages = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = memory.get_history_prompt_text( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if 
memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + +def _handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +) -> Sequence[PromptMessage]: + """Handle completion template processing outside of LLMNode class. + + Args: + template: The completion model prompt template + context: Optional context string + jinja2_variables: Variables for jinja2 template rendering + variable_pool: Variable pool for template conversion + + Returns: + Sequence of prompt messages + """ + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + else: + if context: + template_text = template.text.replace("{#context#}", context) + else: + template_text = template.text + result_text = variable_pool.convert_template(template_text).text + prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) + prompt_messages.append(prompt_message) + return prompt_messages diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 744dfd3d8d..e855ab2d2b 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode): ) prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - system_query=query, + user_query=query, memory=memory, model_config=model_config, - files=files, + user_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, + variable_pool=variable_pool, + jinja2_variables=[], ) # handle invoke result diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py index 70ae29d45f..474ecefe76 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -1,11 +1,11 @@ from collections.abc import Sequence -from enum import Enum +from enum import StrEnum from typing import Optional from core.workflow.nodes.base import BaseNodeData -class WriteMode(str, Enum): +class WriteMode(StrEnum): OVER_WRITE = "over-write" APPEND = "append" CLEAR = "clear" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 4984ea5a23..9b2da642ac 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,10 +5,9 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.app_config.entities import FileUploadConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.models import File, FileTransferMethod, ImageConfig +from core.file.models import File from core.workflow.callbacks import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError @@ -18,9 +17,8 @@ from core.workflow.graph_engine.entities.graph_init_params 
import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes import NodeType -from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent -from core.workflow.nodes.llm import LLMNodeData from core.workflow.nodes.node_mapping import node_type_classes_mapping from factories import file_factory from models.enums import UserFrom @@ -115,7 +113,12 @@ class WorkflowEntry: @classmethod def single_step_run( - cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict + cls, + *, + workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -135,13 +138,9 @@ class WorkflowEntry: raise ValueError("nodes not found in workflow graph") # fetch node config from node id - node_config = None - for node in nodes: - if node.get("id") == node_id: - node_config = node - break - - if not node_config: + try: + node_config = next(filter(lambda node: node["id"] == node_id, nodes)) + except StopIteration: raise ValueError("node id not found in workflow graph") # Get node class @@ -153,11 +152,7 @@ class WorkflowEntry: raise ValueError(f"Node class not found for node type {node_type}") # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - ) + variable_pool = VariablePool(environment_variables=workflow.environment_variables) # init graph graph = Graph.init(graph_config=workflow.graph_dict) @@ -183,28 +178,24 @@ class WorkflowEntry: try: # variable selector to variable mapping - try: - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, config=node_config - ) - except NotImplementedError: - variable_mapping = {} - - cls.mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_type=node_type, - node_data=node_instance.node_data, + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, config=node_config ) + except NotImplementedError: + variable_mapping = {} + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + ) + try: # run node generator = node_instance.run() - - return node_instance, generator except Exception as e: raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + return node_instance, generator @classmethod def run_free_node( @@ -332,12 +323,11 @@ class WorkflowEntry: @classmethod def mapping_user_inputs_to_variable_pool( cls, + *, variable_mapping: Mapping[str, Sequence[str]], user_inputs: dict, variable_pool: VariablePool, tenant_id: str, - node_type: NodeType, - node_data: BaseNodeData, ) -> None: for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable @@ -355,40 +345,21 @@ class WorkflowEntry: # fetch variable node id from variable selector variable_node_id = variable_selector[0] variable_key_list = variable_selector[1:] - variable_key_list = cast(list[str], variable_key_list) + variable_key_list = list(variable_key_list) # 
get input value input_value = user_inputs.get(node_variable) if not input_value: input_value = user_inputs.get(node_variable_key) - # FIXME: temp fix for image type - if node_type == NodeType.LLM: - new_value = [] - if isinstance(input_value, list): - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in input_value: - if isinstance(item, dict) and "type" in item and item["type"] == "image": - transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) - mapping = { - "id": item.get("id"), - "transfer_method": transfer_method, - "upload_file_id": item.get("upload_file_id"), - "url": item.get("url"), - } - config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None) - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - ) - new_value.append(file) - - if new_value: - input_value = new_value + if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: + input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + if ( + isinstance(input_value, list) + and all(isinstance(item, dict) for item in input_value) + and all("type" in item and "transfer_method" in item for item in input_value) + ): + input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) # append variable and value to variable pool variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 5af45e1e50..24fa013697 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -33,7 +33,7 @@ def handle(sender, **kwargs): raise NotFound("Document not found") document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py index a80572c0de..f225ef8e88 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from events.message_event import message_was_created @@ -17,5 +17,5 @@ def handle(sender, **kwargs): db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == application_generate_entity.model_conf.provider, - ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)}) + ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)}) db.session.commit() diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 11a7544274..b26caa8671 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from datetime import datetime, 
timedelta, timezone +from datetime import UTC, datetime, timedelta from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas @@ -67,7 +67,7 @@ class AzureBlobStorage(BaseStorage): account_key=self.account_key, resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), + expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py index 415bf251f6..e7fa405afa 100644 --- a/api/extensions/storage/storage_type.py +++ b/api/extensions/storage/storage_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class StorageType(str, Enum): +class StorageType(StrEnum): ALIYUN_OSS = "aliyun-oss" AZURE_BLOB = "azure-blob" BAIDU_OBS = "baidu-obs" diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 4da0140d19..ad8dba8190 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,10 +1,11 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence -from typing import Any +from typing import Any, cast import httpx from sqlalchemy import select +from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.helper import ssrf_proxy from extensions.ext_database import db @@ -71,7 +72,12 @@ def build_from_mapping( transfer_method=transfer_method, ) - if not _is_file_valid_with_config(file=file, config=config): + if not _is_file_valid_with_config( + input_file_type=mapping.get("type", FileType.CUSTOM), + file_extension=file.extension, + file_transfer_method=file.transfer_method, + config=config, + ): raise ValueError(f"File validation failed for file: {file.filename}") return file @@ -80,12 +86,9 @@ def build_from_mapping( def build_from_mappings( *, mappings: Sequence[Mapping[str, Any]], - config: FileUploadConfig | None, + config: FileUploadConfig | None = None, tenant_id: str, ) -> Sequence[File]: - if not config: - return [] - files = [ build_from_mapping( mapping=mapping, @@ -96,13 +99,14 @@ def build_from_mappings( ] if ( + config # If image config is set. 
- config.image_config + and config.image_config # And the number of image files exceeds the maximum limit and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits ): raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") - if config.number_limits and len(files) > config.number_limits: + if config and config.number_limits and len(files) > config.number_limits: raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") return files @@ -114,17 +118,18 @@ def _build_from_local_file( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: - file_type = FileType.value_of(mapping.get("type")) stmt = select(UploadFile).where( UploadFile.id == mapping.get("upload_file_id"), UploadFile.tenant_id == tenant_id, ) row = db.session.scalar(stmt) - if row is None: raise ValueError("Invalid upload file") + file_type = FileType(mapping.get("type")) + file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type) + return File( id=mapping.get("id"), filename=row.name, @@ -152,11 +157,14 @@ def _build_from_remote_url( mime_type, filename, file_size = _get_remote_file_info(url) extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" + file_type = FileType(mapping.get("type")) + file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type) + return File( id=mapping.get("id"), filename=filename, tenant_id=tenant_id, - type=FileType.value_of(mapping.get("type")), + type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, @@ -171,6 +179,7 @@ def _get_remote_file_info(url: str): mime_type = mimetypes.guess_type(filename)[0] or "" resp = ssrf_proxy.head(url, follow_redirects=True) + resp = cast(httpx.Response, resp) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): filename = str(content_disposition.split("filename=")[-1].strip('"')) @@ -180,20 +189,6 @@ def _get_remote_file_info(url: str): return mime_type, filename, file_size -def _get_file_type_by_mimetype(mime_type: str) -> FileType: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - def _build_from_tool_file( *, mapping: Mapping[str, Any], @@ -213,7 +208,8 @@ def _build_from_tool_file( raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") extension = "." + tool_file.file_key.split(".")[-1] if "." 
in tool_file.file_key else ".bin" - file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype)) + file_type = FileType(mapping.get("type")) + file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype) return File( id=mapping.get("id"), @@ -229,18 +225,72 @@ def _build_from_tool_file( ) -def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: - if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: +def _is_file_valid_with_config( + *, + input_file_type: str, + file_extension: str, + file_transfer_method: FileTransferMethod, + config: FileUploadConfig, +) -> bool: + if ( + config.allowed_file_types + and input_file_type not in config.allowed_file_types + and input_file_type != FileType.CUSTOM + ): return False - if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions: + if ( + input_file_type == FileType.CUSTOM + and config.allowed_file_extensions is not None + and file_extension not in config.allowed_file_extensions + ): return False - if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods: + if config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: return False - if file.type == FileType.IMAGE and config.image_config: - if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: + if input_file_type == FileType.IMAGE and config.image_config: + if config.image_config.transfer_methods and file_transfer_method not in config.image_config.transfer_methods: return False return True + + +def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType: + """ + If custom type, try to guess the file type by extension and mime_type. 
+ """ + if file_type != FileType.CUSTOM: + return FileType(file_type) + guessed_type = None + if extension: + guessed_type = _get_file_type_by_extension(extension) + if guessed_type is None and mime_type: + guessed_type = _get_file_type_by_mimetype(mime_type) + return guessed_type or FileType.CUSTOM + + +def _get_file_type_by_extension(extension: str) -> FileType | None: + extension = extension.lstrip(".") + if extension in IMAGE_EXTENSIONS: + return FileType.IMAGE + elif extension in VIDEO_EXTENSIONS: + return FileType.VIDEO + elif extension in AUDIO_EXTENSIONS: + return FileType.AUDIO + elif extension in DOCUMENT_EXTENSIONS: + return FileType.DOCUMENT + + +def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: + if "image" in mime_type: + file_type = FileType.IMAGE + elif "video" in mime_type: + file_type = FileType.VIDEO + elif "audio" in mime_type: + file_type = FileType.AUDIO + elif "text" in mime_type or "pdf" in mime_type: + file_type = FileType.DOCUMENT + else: + file_type = FileType.CUSTOM + return file_type diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index aa353a3cc1..abb27fdad1 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -190,3 +190,12 @@ app_site_fields = { "show_workflow_steps": fields.Boolean, "use_icon_as_answer_icon": fields.Boolean, } + +app_import_fields = { + "id": fields.String, + "status": fields.String, + "app_id": fields.String, + "current_dsl_version": fields.String, + "imported_dsl_version": fields.String, + "error": fields.String, +} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index b32423f10c..533e3a0837 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -41,6 +41,7 @@ dataset_retrieval_model_fields = { external_retrieval_model_fields = { "top_k": fields.Integer, "score_threshold": fields.Float, + "score_threshold_enabled": fields.Boolean, } tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} diff --git a/api/libs/helper.py b/api/libs/helper.py index dd7c3fb7d0..6b8da40562 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -31,9 +31,12 @@ class AppIconUrlField(fields.Raw): if obj is None: return None - from models.model import IconType + from models.model import App, IconType - if obj.icon_type == IconType.IMAGE.value: + if isinstance(obj, dict) and "app" in obj: + obj = obj["app"] + + if isinstance(obj, App) and obj.icon_type == IconType.IMAGE.value: return file_helpers.get_signed_file_url(obj.icon) return None diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index e747ea97ad..53aa0f2d45 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -70,7 +70,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -106,7 +106,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: 
new_data_source_binding = DataSourceOauthBinding( @@ -141,7 +141,7 @@ class NotionOAuth(OAuthDataSource): } data_source_binding.source_info = new_source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: raise ValueError("Data source binding not found") diff --git a/api/models/account.py b/api/models/account.py index 0a2c492ae6..9eb52573d3 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -10,7 +10,7 @@ from models.base import Base from .types import StringUUID -class AccountStatus(str, enum.Enum): +class AccountStatus(enum.StrEnum): PENDING = "pending" UNINITIALIZED = "uninitialized" ACTIVE = "active" @@ -111,6 +111,10 @@ class Account(UserMixin, Base): def is_admin_or_owner(self): return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) + @property + def is_admin(self): + return TenantAccountRole.is_admin_role(self._current_tenant.current_role) + @property def is_editor(self): return TenantAccountRole.is_editing_role(self._current_tenant.current_role) @@ -124,12 +128,12 @@ class Account(UserMixin, Base): return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR -class TenantStatus(str, enum.Enum): +class TenantStatus(enum.StrEnum): NORMAL = "normal" ARCHIVE = "archive" -class TenantAccountRole(str, enum.Enum): +class TenantAccountRole(enum.StrEnum): OWNER = "owner" ADMIN = "admin" EDITOR = "editor" @@ -150,6 +154,10 @@ class TenantAccountRole(str, enum.Enum): def is_privileged_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + @staticmethod + def is_admin_role(role: str) -> bool: + return role and role == TenantAccountRole.ADMIN + @staticmethod def is_non_owner_role(role: str) -> bool: return role and role in { diff --git a/api/models/dataset.py b/api/models/dataset.py index a8b2c419d1..8ab957e875 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -23,7 +23,7 @@ from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID -class DatasetPermissionEnum(str, enum.Enum): +class DatasetPermissionEnum(enum.StrEnum): ONLY_ME = "only_me" ALL_TEAM = "all_team_members" PARTIAL_TEAM = "partial_members" diff --git a/api/models/enums.py b/api/models/enums.py index a83d35e042..7b9500ebe4 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,16 +1,16 @@ -from enum import Enum +from enum import StrEnum -class CreatedByRole(str, Enum): +class CreatedByRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" -class UserFrom(str, Enum): +class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" -class WorkflowRunTriggeredFrom(str, Enum): +class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" diff --git a/api/models/model.py b/api/models/model.py index 1a9ebbc0a6..c67c1051f2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from models.workflow import Workflow +from enum import StrEnum from typing import Any, Literal import sqlalchemy as sa @@ -38,7 +39,7 @@ class DifySetup(Base): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class AppMode(str, Enum): +class AppMode(StrEnum): COMPLETION = "completion" WORKFLOW = "workflow" CHAT = "chat" @@ -74,7 +75,7 @@ class App(Base): 
name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) - icon_type = db.Column(db.String(255), nullable=True) + icon_type = db.Column(db.String(255), nullable=True) # image, emoji icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(StringUUID, nullable=True) @@ -261,8 +262,8 @@ class AppModelConfig(Base): return app @property - def model_dict(self): - return json.loads(self.model) if self.model else None + def model_dict(self) -> dict: + return json.loads(self.model) if self.model else {} @property def suggested_questions_list(self) -> list: @@ -610,11 +611,8 @@ class Conversation(Base): app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() ) - - if not app_model_config: - return {} - - model_config = app_model_config.to_dict() + if app_model_config: + model_config = app_model_config.to_dict() model_config["model_id"] = self.model_id model_config["provider"] = self.model_provider diff --git a/api/models/task.py b/api/models/task.py index 6fab2a72c2..b6783873df 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from celery import states @@ -17,8 +17,8 @@ class CeleryTask(Base): result = db.Column(db.PickleType, nullable=True) date_done = db.Column( db.DateTime, - default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + default=lambda: datetime.now(UTC).replace(tzinfo=None), + onupdate=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True, ) traceback = db.Column(db.Text, nullable=True) @@ -38,4 +38,4 @@ class CeleryTaskSet(Base): id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True) diff --git a/api/models/workflow.py b/api/models/workflow.py index 721082d06d..6545bdc0ac 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from datetime import datetime, timezone +from datetime import UTC, datetime from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union @@ -112,7 +112,7 @@ class Workflow(Base): ) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, default=datetime.now(tz=timezone.utc), server_onupdate=func.current_timestamp() + sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp() ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" diff --git a/api/poetry.lock b/api/poetry.lock index dc9ebd1b75..4d9c12b191 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -114,7 +114,6 @@ files = [ [package.dependencies] aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" -async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" frozenlist = ">=1.1.1" 
multidict = ">=4.5,<7.0" @@ -483,10 +482,8 @@ files = [ ] [package.dependencies] -exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" -typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] @@ -519,9 +516,6 @@ files = [ {file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"}, ] -[package.dependencies] -typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""} - [package.extras] tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] @@ -951,6 +945,10 @@ files = [ {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, @@ -963,8 +961,14 @@ files = [ {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, + {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = 
"sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, + {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, @@ -975,8 +979,24 @@ files = [ {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, + {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, + {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = 
"sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, + {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, + {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, @@ -986,6 +1006,10 @@ files = [ {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, @@ -997,6 +1021,10 @@ files = [ {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = 
"sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, @@ -1009,6 +1037,10 @@ files = [ {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, @@ -1021,6 +1053,10 @@ files = [ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, @@ -1092,10 +1128,8 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "os_name == \"nt\""} -importlib-metadata = {version = ">=4.6", markers = "python_full_version < \"3.10.2\""} packaging = ">=19.1" pyproject_hooks = "*" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} [package.extras] docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"] @@ -1388,36 
+1422,40 @@ files = [ [[package]] name = "chroma-hnswlib" -version = "0.7.3" +version = "0.7.6" description = "Chromas fork of hnswlib" optional = false python-versions = "*" files = [ - {file = "chroma-hnswlib-0.7.3.tar.gz", hash = "sha256:b6137bedde49fffda6af93b0297fe00429fc61e5a072b1ed9377f909ed95a932"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59d6a7c6f863c67aeb23e79a64001d537060b6995c3eca9a06e349ff7b0998ca"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d71a3f4f232f537b6152947006bd32bc1629a8686df22fd97777b70f416c127a"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c92dc1ebe062188e53970ba13f6b07e0ae32e64c9770eb7f7ffa83f149d4210"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49da700a6656fed8753f68d44b8cc8ae46efc99fc8a22a6d970dc1697f49b403"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:108bc4c293d819b56476d8f7865803cb03afd6ca128a2a04d678fffc139af029"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:11e7ca93fb8192214ac2b9c0943641ac0daf8f9d4591bb7b73be808a83835667"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f552e4d23edc06cdeb553cdc757d2fe190cdeb10d43093d6a3319f8d4bf1c6b"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f96f4d5699e486eb1fb95849fe35ab79ab0901265805be7e60f4eaa83ce263ec"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:368e57fe9ebae05ee5844840fa588028a023d1182b0cfdb1d13f607c9ea05756"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-win_amd64.whl", hash = "sha256:b7dca27b8896b494456db0fd705b689ac6b73af78e186eb6a42fea2de4f71c6f"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:70f897dc6218afa1d99f43a9ad5eb82f392df31f57ff514ccf4eeadecd62f544"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aef10b4952708f5a1381c124a29aead0c356f8d7d6e0b520b778aaa62a356f4"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ee2d8d1529fca3898d512079144ec3e28a81d9c17e15e0ea4665697a7923253"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a4021a70e898783cd6f26e00008b494c6249a7babe8774e90ce4766dd288c8ba"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a8f61fa1d417fda848e3ba06c07671f14806a2585272b175ba47501b066fe6b1"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d7563be58bc98e8f0866907368e22ae218d6060601b79c42f59af4eccbbd2e0a"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51b8d411486ee70d7b66ec08cc8b9b6620116b650df9c19076d2d8b6ce2ae914"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d706782b628e4f43f1b8a81e9120ac486837fbd9bcb8ced70fe0d9b95c72d77"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:54f053dedc0e3ba657f05fec6e73dd541bc5db5b09aa8bc146466ffb734bdc86"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e607c5a71c610a73167a517062d302c0827ccdd6e259af6e4869a5c1306ffb5d"}, - {file = 
"chroma_hnswlib-0.7.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c2358a795870156af6761890f9eb5ca8cade57eb10c5f046fe94dae1faa04b9e"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cea425df2e6b8a5e201fff0d922a1cc1d165b3cfe762b1408075723c8892218"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:454df3dd3e97aa784fba7cf888ad191e0087eef0fd8c70daf28b753b3b591170"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:df587d15007ca701c6de0ee7d5585dd5e976b7edd2b30ac72bc376b3c3f85882"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f35192fbbeadc8c0633f0a69c3d3e9f1a4eab3a46b65458bbcbcabdd9e895c36"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f007b608c96362b8f0c8b6b2ac94f67f83fcbabd857c378ae82007ec92f4d82"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:456fd88fa0d14e6b385358515aef69fc89b3c2191706fd9aee62087b62aad09c"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5dfaae825499c2beaa3b75a12d7ec713b64226df72a5c4097203e3ed532680da"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-win_amd64.whl", hash = "sha256:2487201982241fb1581be26524145092c95902cb09fc2646ccfbc407de3328ec"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81181d54a2b1e4727369486a631f977ffc53c5533d26e3d366dda243fb0998ca"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4b4ab4e11f1083dd0a11ee4f0e0b183ca9f0f2ed63ededba1935b13ce2b3606f"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53db45cd9173d95b4b0bdccb4dbff4c54a42b51420599c32267f3abbeb795170"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c093f07a010b499c00a15bc9376036ee4800d335360570b14f7fe92badcdcf9"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-win_amd64.whl", hash = "sha256:0540b0ac96e47d0aa39e88ea4714358ae05d64bbe6bf33c52f316c664190a6a3"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e87e9b616c281bfbe748d01705817c71211613c3b063021f7ed5e47173556cb7"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ec5ca25bc7b66d2ecbf14502b5729cde25f70945d22f2aaf523c2d747ea68912"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:305ae491de9d5f3c51e8bd52d84fdf2545a4a2bc7af49765cda286b7bb30b1d4"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:822ede968d25a2c88823ca078a58f92c9b5c4142e38c7c8b4c48178894a0a3c5"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2fe6ea949047beed19a94b33f41fe882a691e58b70c55fdaa90274ae78be046f"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feceff971e2a2728c9ddd862a9dd6eb9f638377ad98438876c9aeac96c9482f5"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb0633b60e00a2b92314d0bf5bbc0da3d3320be72c7e3f4a9b19f4609dc2b2ab"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a566abe32fab42291f766d667bdbfa234a7f457dcbd2ba19948b7a978c8ca624"}, + {file = 
"chroma_hnswlib-0.7.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6be47853d9a58dedcfa90fc846af202b071f028bbafe1d8711bf64fe5a7f6111"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a7af35bdd39a88bffa49f9bb4bf4f9040b684514a024435a1ef5cdff980579d"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a53b1f1551f2b5ad94eb610207bde1bb476245fc5097a2bec2b476c653c58bde"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3085402958dbdc9ff5626ae58d696948e715aef88c86d1e3f9285a88f1afd3bc"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:77326f658a15adfb806a16543f7db7c45f06fd787d699e643642d6bde8ed49c4"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:93b056ab4e25adab861dfef21e1d2a2756b18be5bc9c292aa252fa12bb44e6ae"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fe91f018b30452c16c811fd6c8ede01f84e5a9f3c23e0758775e57f1c3778871"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6c0e627476f0f4d9e153420d36042dd9c6c3671cfd1fe511c0253e38c2a1039"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e9796a4536b7de6c6d76a792ba03e08f5aaa53e97e052709568e50b4d20c04f"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-win_amd64.whl", hash = "sha256:d30e2db08e7ffdcc415bd072883a322de5995eb6ec28a8f8c054103bbd3ec1e0"}, + {file = "chroma_hnswlib-0.7.6.tar.gz", hash = "sha256:4dce282543039681160259d29fcde6151cc9106c6461e0485f57cdccd83059b7"}, ] [package.dependencies] @@ -1425,26 +1463,26 @@ numpy = "*" [[package]] name = "chromadb" -version = "0.5.1" +version = "0.5.20" description = "Chroma." 
optional = false python-versions = ">=3.8" files = [ - {file = "chromadb-0.5.1-py3-none-any.whl", hash = "sha256:61f1f75a672b6edce7f1c8875c67e2aaaaf130dc1c1684431fbc42ad7240d01d"}, - {file = "chromadb-0.5.1.tar.gz", hash = "sha256:e2b2b6a34c2a949bedcaa42fa7775f40c7f6667848fc8094dcbf97fc0d30bee7"}, + {file = "chromadb-0.5.20-py3-none-any.whl", hash = "sha256:9550ba1b6dce911e35cac2568b301badf4b42f457b99a432bdeec2b6b9dd3680"}, + {file = "chromadb-0.5.20.tar.gz", hash = "sha256:19513a23b2d20059866216bfd80195d1d4a160ffba234b8899f5e80978160ca7"}, ] [package.dependencies] bcrypt = ">=4.0.1" build = ">=1.0.3" -chroma-hnswlib = "0.7.3" +chroma-hnswlib = "0.7.6" fastapi = ">=0.95.2" grpcio = ">=1.58.0" httpx = ">=0.27.0" importlib-resources = "*" kubernetes = ">=28.1.0" mmh3 = ">=4.0.1" -numpy = ">=1.22.5,<2.0.0" +numpy = ">=1.22.5" onnxruntime = ">=1.14.1" opentelemetry-api = ">=1.2.0" opentelemetry-exporter-otlp-proto-grpc = ">=1.2.0" @@ -1456,7 +1494,7 @@ posthog = ">=2.4.0" pydantic = ">=1.9" pypika = ">=0.48.9" PyYAML = ">=6.0.0" -requests = ">=2.28" +rich = ">=10.11.0" tenacity = ">=8.2.3" tokenizers = ">=0.13.2" tqdm = ">=4.65.0" @@ -2036,9 +2074,6 @@ files = [ {file = "dataclass_wizard-0.28.0-py2.py3-none-any.whl", hash = "sha256:996fa46475b9192a48a057c34f04597bc97be5bc2f163b99cb1de6f778ca1f7f"}, ] -[package.dependencies] -typing-extensions = {version = ">=4", markers = "python_version == \"3.9\" or python_version == \"3.10\""} - [package.extras] dev = ["Sphinx (==7.4.7)", "Sphinx (==8.1.3)", "bump2version (==1.0.1)", "coverage (>=6.2)", "dataclass-factory (==2.16)", "dataclass-wizard[toml]", "dataclasses-json (==0.6.7)", "flake8 (>=3)", "jsons (==1.6.3)", "pip (>=21.3.1)", "pytest (==8.3.3)", "pytest-cov (==6.0.0)", "pytest-mock (>=3.6.1)", "pytimeparse (==1.1.8)", "sphinx-issues (==5.0.0)", "tomli (>=2,<3)", "tomli (>=2,<3)", "tomli-w (>=1,<2)", "tox (==4.23.2)", "twine (==5.1.1)", "watchdog[watchmedo] (==6.0.0)", "wheel (==0.45.0)"] timedelta = ["pytimeparse (>=1.1.7)"] @@ -2410,18 +2445,19 @@ files = [ tests = ["pytest"] [[package]] -name = "exceptiongroup" -version = "1.2.2" -description = "Backport of PEP 654 (exception groups)" +name = "faker" +version = "32.1.0" +description = "Faker is a Python package that generates fake data for you." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, - {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, + {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"}, + {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"}, ] -[package.extras] -test = ["pytest (>=6)"] +[package.dependencies] +python-dateutil = ">=2.4" +typing-extensions = "*" [[package]] name = "fal-client" @@ -3030,59 +3066,54 @@ files = [ [[package]] name = "gevent" -version = "23.9.1" +version = "24.11.1" description = "Coroutine-based network library" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "gevent-23.9.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:a3c5e9b1f766a7a64833334a18539a362fb563f6c4682f9634dea72cbe24f771"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b101086f109168b23fa3586fccd1133494bdb97f86920a24dc0b23984dc30b69"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:36a549d632c14684bcbbd3014a6ce2666c5f2a500f34d58d32df6c9ea38b6535"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:272cffdf535978d59c38ed837916dfd2b5d193be1e9e5dcc60a5f4d5025dd98a"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcb8612787a7f4626aa881ff15ff25439561a429f5b303048f0fca8a1c781c39"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:d57737860bfc332b9b5aa438963986afe90f49645f6e053140cfa0fa1bdae1ae"}, - {file = "gevent-23.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5f3c781c84794926d853d6fb58554dc0dcc800ba25c41d42f6959c344b4db5a6"}, - {file = "gevent-23.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dbb22a9bbd6a13e925815ce70b940d1578dbe5d4013f20d23e8a11eddf8d14a7"}, - {file = "gevent-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:707904027d7130ff3e59ea387dddceedb133cc742b00b3ffe696d567147a9c9e"}, - {file = "gevent-23.9.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:45792c45d60f6ce3d19651d7fde0bc13e01b56bb4db60d3f32ab7d9ec467374c"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e24c2af9638d6c989caffc691a039d7c7022a31c0363da367c0d32ceb4a0648"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e1ead6863e596a8cc2a03e26a7a0981f84b6b3e956101135ff6d02df4d9a6b07"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65883ac026731ac112184680d1f0f1e39fa6f4389fd1fc0bf46cc1388e2599f9"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7af500da05363e66f122896012acb6e101a552682f2352b618e541c941a011"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c3e5d2fa532e4d3450595244de8ccf51f5721a05088813c1abd93ad274fe15e7"}, - {file = "gevent-23.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c84d34256c243b0a53d4335ef0bc76c735873986d478c53073861a92566a8d71"}, - {file = 
"gevent-23.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ada07076b380918829250201df1d016bdafb3acf352f35e5693b59dceee8dd2e"}, - {file = "gevent-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:921dda1c0b84e3d3b1778efa362d61ed29e2b215b90f81d498eb4d8eafcd0b7a"}, - {file = "gevent-23.9.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ed7a048d3e526a5c1d55c44cb3bc06cfdc1947d06d45006cc4cf60dedc628904"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c1abc6f25f475adc33e5fc2dbcc26a732608ac5375d0d306228738a9ae14d3b"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4368f341a5f51611411ec3fc62426f52ac3d6d42eaee9ed0f9eebe715c80184e"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:52b4abf28e837f1865a9bdeef58ff6afd07d1d888b70b6804557e7908032e599"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52e9f12cd1cda96603ce6b113d934f1aafb873e2c13182cf8e86d2c5c41982ea"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:de350fde10efa87ea60d742901e1053eb2127ebd8b59a7d3b90597eb4e586599"}, - {file = "gevent-23.9.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fde6402c5432b835fbb7698f1c7f2809c8d6b2bd9d047ac1f5a7c1d5aa569303"}, - {file = "gevent-23.9.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:dd6c32ab977ecf7c7b8c2611ed95fa4aaebd69b74bf08f4b4960ad516861517d"}, - {file = "gevent-23.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:455e5ee8103f722b503fa45dedb04f3ffdec978c1524647f8ba72b4f08490af1"}, - {file = "gevent-23.9.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:7ccf0fd378257cb77d91c116e15c99e533374a8153632c48a3ecae7f7f4f09fe"}, - {file = "gevent-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d163d59f1be5a4c4efcdd13c2177baaf24aadf721fdf2e1af9ee54a998d160f5"}, - {file = "gevent-23.9.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7532c17bc6c1cbac265e751b95000961715adef35a25d2b0b1813aa7263fb397"}, - {file = "gevent-23.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:78eebaf5e73ff91d34df48f4e35581ab4c84e22dd5338ef32714264063c57507"}, - {file = "gevent-23.9.1-cp38-cp38-win32.whl", hash = "sha256:f632487c87866094546a74eefbca2c74c1d03638b715b6feb12e80120960185a"}, - {file = "gevent-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:62d121344f7465e3739989ad6b91f53a6ca9110518231553fe5846dbe1b4518f"}, - {file = "gevent-23.9.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:bf456bd6b992eb0e1e869e2fd0caf817f0253e55ca7977fd0e72d0336a8c1c6a"}, - {file = "gevent-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43daf68496c03a35287b8b617f9f91e0e7c0d042aebcc060cadc3f049aadd653"}, - {file = "gevent-23.9.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:7c28e38dcde327c217fdafb9d5d17d3e772f636f35df15ffae2d933a5587addd"}, - {file = "gevent-23.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fae8d5b5b8fa2a8f63b39f5447168b02db10c888a3e387ed7af2bd1b8612e543"}, - {file = "gevent-23.9.1-cp39-cp39-win32.whl", hash = "sha256:2c7b5c9912378e5f5ccf180d1fdb1e83f42b71823483066eddbe10ef1a2fcaa2"}, - {file = "gevent-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:a2898b7048771917d85a1d548fd378e8a7b2ca963db8e17c6d90c76b495e0e2b"}, - {file = "gevent-23.9.1.tar.gz", hash = "sha256:72c002235390d46f94938a96920d8856d4ffd9ddf62a303a0d7c118894097e34"}, + {file = 
"gevent-24.11.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:92fe5dfee4e671c74ffaa431fd7ffd0ebb4b339363d24d0d944de532409b935e"}, + {file = "gevent-24.11.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bfcfe08d038e1fa6de458891bca65c1ada6d145474274285822896a858c870"}, + {file = "gevent-24.11.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7398c629d43b1b6fd785db8ebd46c0a353880a6fab03d1cf9b6788e7240ee32e"}, + {file = "gevent-24.11.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d7886b63ebfb865178ab28784accd32f287d5349b3ed71094c86e4d3ca738af5"}, + {file = "gevent-24.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9ca80711e6553880974898d99357fb649e062f9058418a92120ca06c18c3c59"}, + {file = "gevent-24.11.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e24181d172f50097ac8fc272c8c5b030149b630df02d1c639ee9f878a470ba2b"}, + {file = "gevent-24.11.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1d4fadc319b13ef0a3c44d2792f7918cf1bca27cacd4d41431c22e6b46668026"}, + {file = "gevent-24.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:3d882faa24f347f761f934786dde6c73aa6c9187ee710189f12dcc3a63ed4a50"}, + {file = "gevent-24.11.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:351d1c0e4ef2b618ace74c91b9b28b3eaa0dd45141878a964e03c7873af09f62"}, + {file = "gevent-24.11.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5efe72e99b7243e222ba0c2c2ce9618d7d36644c166d63373af239da1036bab"}, + {file = "gevent-24.11.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d3b249e4e1f40c598ab8393fc01ae6a3b4d51fc1adae56d9ba5b315f6b2d758"}, + {file = "gevent-24.11.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81d918e952954675f93fb39001da02113ec4d5f4921bf5a0cc29719af6824e5d"}, + {file = "gevent-24.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9c935b83d40c748b6421625465b7308d87c7b3717275acd587eef2bd1c39546"}, + {file = "gevent-24.11.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff96c5739834c9a594db0e12bf59cb3fa0e5102fc7b893972118a3166733d61c"}, + {file = "gevent-24.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d6c0a065e31ef04658f799215dddae8752d636de2bed61365c358f9c91e7af61"}, + {file = "gevent-24.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:97e2f3999a5c0656f42065d02939d64fffaf55861f7d62b0107a08f52c984897"}, + {file = "gevent-24.11.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:a3d75fa387b69c751a3d7c5c3ce7092a171555126e136c1d21ecd8b50c7a6e46"}, + {file = "gevent-24.11.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:beede1d1cff0c6fafae3ab58a0c470d7526196ef4cd6cc18e7769f207f2ea4eb"}, + {file = "gevent-24.11.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85329d556aaedced90a993226d7d1186a539c843100d393f2349b28c55131c85"}, + {file = "gevent-24.11.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:816b3883fa6842c1cf9d2786722014a0fd31b6312cca1f749890b9803000bad6"}, + {file = "gevent-24.11.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b24d800328c39456534e3bc3e1684a28747729082684634789c2f5a8febe7671"}, + {file = "gevent-24.11.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a5f1701ce0f7832f333dd2faf624484cbac99e60656bfbb72504decd42970f0f"}, + {file = 
"gevent-24.11.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d740206e69dfdfdcd34510c20adcb9777ce2cc18973b3441ab9767cd8948ca8a"}, + {file = "gevent-24.11.1-cp312-cp312-win_amd64.whl", hash = "sha256:68bee86b6e1c041a187347ef84cf03a792f0b6c7238378bf6ba4118af11feaae"}, + {file = "gevent-24.11.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:d618e118fdb7af1d6c1a96597a5cd6ac84a9f3732b5be8515c6a66e098d498b6"}, + {file = "gevent-24.11.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2142704c2adce9cd92f6600f371afb2860a446bfd0be5bd86cca5b3e12130766"}, + {file = "gevent-24.11.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:92e0d7759de2450a501effd99374256b26359e801b2d8bf3eedd3751973e87f5"}, + {file = "gevent-24.11.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca845138965c8c56d1550499d6b923eb1a2331acfa9e13b817ad8305dde83d11"}, + {file = "gevent-24.11.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:356b73d52a227d3313f8f828025b665deada57a43d02b1cf54e5d39028dbcf8d"}, + {file = "gevent-24.11.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:58851f23c4bdb70390f10fc020c973ffcf409eb1664086792c8b1e20f25eef43"}, + {file = "gevent-24.11.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1ea50009ecb7f1327347c37e9eb6561bdbc7de290769ee1404107b9a9cba7cf1"}, + {file = "gevent-24.11.1-cp313-cp313-win_amd64.whl", hash = "sha256:ec68e270543ecd532c4c1d70fca020f90aa5486ad49c4f3b8b2e64a66f5c9274"}, + {file = "gevent-24.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9347690f4e53de2c4af74e62d6fabc940b6d4a6cad555b5a379f61e7d3f2a8e"}, + {file = "gevent-24.11.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8619d5c888cb7aebf9aec6703e410620ef5ad48cdc2d813dd606f8aa7ace675f"}, + {file = "gevent-24.11.1-cp39-cp39-win32.whl", hash = "sha256:c6b775381f805ff5faf250e3a07c0819529571d19bb2a9d474bee8c3f90d66af"}, + {file = "gevent-24.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:1c3443b0ed23dcb7c36a748d42587168672953d368f2956b17fad36d43b58836"}, + {file = "gevent-24.11.1-pp310-pypy310_pp73-macosx_11_0_universal2.whl", hash = "sha256:f43f47e702d0c8e1b8b997c00f1601486f9f976f84ab704f8f11536e3fa144c9"}, + {file = "gevent-24.11.1.tar.gz", hash = "sha256:8bd1419114e9e4a3ed33a5bad766afff9a3cf765cb440a582a1b3a9bc80c1aca"}, ] [package.dependencies] -cffi = {version = ">=1.12.2", markers = "platform_python_implementation == \"CPython\" and sys_platform == \"win32\""} -greenlet = [ - {version = ">=2.0.0", markers = "platform_python_implementation == \"CPython\" and python_version < \"3.11\""}, - {version = ">=3.0rc3", markers = "platform_python_implementation == \"CPython\" and python_version >= \"3.11\""}, -] +cffi = {version = ">=1.17.1", markers = "platform_python_implementation == \"CPython\" and sys_platform == \"win32\""} +greenlet = {version = ">=3.1.1", markers = "platform_python_implementation == \"CPython\""} "zope.event" = "*" "zope.interface" = "*" @@ -3090,8 +3121,8 @@ greenlet = [ dnspython = ["dnspython (>=1.16.0,<2.0)", "idna"] docs = ["furo", "repoze.sphinx.autointerface", "sphinx", "sphinxcontrib-programoutput", "zope.schema"] monitor = ["psutil (>=5.7.0)"] -recommended = ["cffi (>=1.12.2)", "dnspython (>=1.16.0,<2.0)", "idna", "psutil (>=5.7.0)"] -test = ["cffi (>=1.12.2)", "coverage (>=5.0)", "dnspython (>=1.16.0,<2.0)", "idna", "objgraph", "psutil (>=5.7.0)", "requests", "setuptools"] +recommended = ["cffi (>=1.17.1)", "dnspython 
(>=1.16.0,<2.0)", "idna", "psutil (>=5.7.0)"] +test = ["cffi (>=1.17.1)", "coverage (>=5.0)", "dnspython (>=1.16.0,<2.0)", "idna", "objgraph", "psutil (>=5.7.0)", "requests"] [[package]] name = "gmpy2" @@ -3200,14 +3231,8 @@ files = [ [package.dependencies] google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" -grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, -] -grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, -] +grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} +grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -5540,9 +5565,6 @@ files = [ {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} - [[package]] name = "multiprocess" version = "0.70.17" @@ -6400,7 +6422,6 @@ bottleneck = {version = ">=1.3.6", optional = true, markers = "extra == \"perfor numba = {version = ">=0.56.4", optional = true, markers = "extra == \"performance\""} numexpr = {version = ">=2.8.4", optional = true, markers = "extra == \"performance\""} numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -6684,7 +6705,6 @@ files = [ deprecation = ">=2.1.0,<3.0.0" httpx = {version = ">=0.26,<0.28", extras = ["http2"]} pydantic = ">=1.9,<3.0" -strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""} [[package]] name = "posthog" @@ -7416,9 +7436,6 @@ files = [ {file = "pypdf-5.1.0.tar.gz", hash = "sha256:425a129abb1614183fd1aca6982f650b47f8026867c0ce7c4b9f281c443d2740"}, ] -[package.dependencies] -typing_extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} - [package.extras] crypto = ["cryptography"] cryptodome = ["PyCryptodome"] @@ -7507,11 +7524,9 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} -exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=1.5,<2" -tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] @@ -7549,7 +7564,6 @@ files = [ [package.dependencies] pytest = ">=8.3.3" -tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} [package.extras] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] @@ -8367,7 +8381,6 @@ files = [ [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" -typing-extensions = 
{version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -9199,22 +9212,6 @@ httpx = {version = ">=0.26,<0.28", extras = ["http2"]} python-dateutil = ">=2.8.2,<3.0.0" typing-extensions = ">=4.2.0,<5.0.0" -[[package]] -name = "strenum" -version = "0.4.15" -description = "An Enum that inherits from str." -optional = false -python-versions = "*" -files = [ - {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, - {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, -] - -[package.extras] -docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] -release = ["twine"] -test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] - [[package]] name = "strictyaml" version = "1.7.3" @@ -9621,17 +9618,6 @@ files = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] -[[package]] -name = "tomli" -version = "2.1.0" -description = "A lil' TOML parser" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"}, - {file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"}, -] - [[package]] name = "tos" version = "2.7.2" @@ -10052,7 +10038,6 @@ h11 = ">=0.8" httptools = {version = ">=0.5.0", optional = true, markers = "extra == \"standard\""} python-dotenv = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} pyyaml = {version = ">=5.1", optional = true, markers = "extra == \"standard\""} -typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} uvloop = {version = ">=0.14.0,<0.15.0 || >0.15.0,<0.15.1 || >0.15.1", optional = true, markers = "(sys_platform != \"win32\" and sys_platform != \"cygwin\") and platform_python_implementation != \"PyPy\" and extra == \"standard\""} watchfiles = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} websockets = {version = ">=10.4", optional = true, markers = "extra == \"standard\""} @@ -11035,5 +11020,5 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" -python-versions = ">=3.10,<3.13" -content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69" +python-versions = ">=3.11,<3.13" +content-hash = "983ba4f2cb89f0c867fc50cb48677cad9343f7f0828c7082cb0b5cf171d716fb" diff --git a/api/pyproject.toml b/api/pyproject.toml index 0d87c1b1c8..79857f8163 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,5 +1,5 @@ [project] -requires-python = ">=3.10,<3.13" +requires-python = ">=3.11,<3.13" [build-system] requires = ["poetry-core"] @@ -131,7 +131,7 @@ flask-login = "~0.6.3" flask-migrate = "~4.0.5" flask-restful = "~0.3.10" flask-sqlalchemy = "~3.1.1" -gevent = "~23.9.1" +gevent = "~24.11.1" gmpy2 = "~2.2.1" google-ai-generativelanguage = "0.6.9" google-api-core = "2.18.0" @@ -163,7 +163,7 @@ pydantic-settings = "~2.6.0" pydantic_extra_types = "~2.9.0" pyjwt = "~2.8.0" pypdfium2 = "~4.17.0" -python = ">=3.10,<3.13" +python = ">=3.11,<3.13" python-docx = "~1.1.0" python-dotenv = "1.0.0" pyyaml = "~6.0.1" @@ -242,7 +242,7 @@ tos = "~2.7.1" [tool.poetry.group.vdb.dependencies] alibabacloud_gpdb20160503 = "~3.8.0" alibabacloud_tea_openapi = "~0.3.9" -chromadb = "0.5.1" +chromadb = "0.5.20" clickhouse-connect = "~0.7.16" 
couchbase = "~4.3.0" elasticsearch = "8.14.0" @@ -268,6 +268,7 @@ weaviate-client = "~3.21.0" optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" +faker = "~32.1.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" diff --git a/api/pytest.ini b/api/pytest.ini index a23a4b3f3d..993da4c9a7 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -20,6 +20,7 @@ env = OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451 TEI_RERANK_SERVER_URL = http://a.abc.com:11451 + TEI_API_KEY = ttttttttttttttt UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa XINFERENCE_CHAT_MODEL_UID = chat diff --git a/api/services/account_service.py b/api/services/account_service.py index 3eca0e62e7..2ab784e2b5 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -4,7 +4,7 @@ import logging import random import secrets import uuid -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from hashlib import sha256 from typing import Any, Optional @@ -115,15 +115,15 @@ class AccountService: available_ta.current = True db.session.commit() - if datetime.now(timezone.utc).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): - account.last_active_at = datetime.now(timezone.utc).replace(tzinfo=None) + if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): + account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return account @staticmethod def get_account_jwt_token(account: Account) -> str: - exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) + exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) exp = int(exp_dt.timestamp()) payload = { "user_id": account.id, @@ -160,7 +160,7 @@ class AccountService: if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() @@ -253,7 +253,7 @@ class AccountService: # If it exists, update the record account_integrate.open_id = open_id account_integrate.encrypted_token = "" # todo - account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None) else: # If it does not exist, create a new record account_integrate = AccountIntegrate( @@ -288,7 +288,7 @@ class AccountService: @staticmethod def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" - account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.last_login_at = datetime.now(UTC).replace(tzinfo=None) account.last_login_ip = ip_address db.session.add(account) db.session.commit() @@ -765,7 +765,7 @@ class RegisterService: ) account.last_login_ip = ip_address - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) @@ -805,7 +805,7 @@ class RegisterService: is_setup=is_setup, ) account.status = AccountStatus.ACTIVE.value if not status else status.value - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = 
datetime.now(UTC).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 915d37ec03..f45c21cb18 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -429,7 +429,7 @@ class AppAnnotationService: raise NotFound("App annotation not found") annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id - annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(annotation_setting) db.session.commit() diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py new file mode 100644 index 0000000000..7ae91e8f42 --- /dev/null +++ b/api/services/app_dsl_service.py @@ -0,0 +1,663 @@ +import logging +import uuid +from enum import StrEnum +from typing import Optional +from uuid import uuid4 + +import yaml +from packaging import version +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.helper import ssrf_proxy +from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.entities.plugin import PluginDependency +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from core.workflow.nodes.llm.entities import LLMNodeData +from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData +from core.workflow.nodes.tool.entities import ToolNodeData +from events.app_event import app_model_config_was_updated, app_was_created +from extensions.ext_redis import redis_client +from factories import variable_factory +from models import Account, App, AppMode +from models.model import AppModelConfig +from models.workflow import Workflow +from services.plugin.dependencies_analysis import DependenciesAnalysisService +from services.workflow_service import WorkflowService + +logger = logging.getLogger(__name__) + +IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" +IMPORT_INFO_REDIS_EXPIRY = 2 * 60 * 60 # 2 hours +CURRENT_DSL_VERSION = "0.1.3" +DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB + + +class ImportMode(StrEnum): + YAML_CONTENT = "yaml-content" + YAML_URL = "yaml-url" + + +class ImportStatus(StrEnum): + COMPLETED = "completed" + COMPLETED_WITH_WARNINGS = "completed-with-warnings" + PENDING = "pending" + FAILED = "failed" + + +class Import(BaseModel): + id: str + status: ImportStatus + app_id: Optional[str] = None + current_dsl_version: str = CURRENT_DSL_VERSION + imported_dsl_version: str = "" + dependencies: list[PluginDependency] = Field(default_factory=list) + error: str = "" + + +def _check_version_compatibility(imported_version: str) -> ImportStatus: + """Determine import status based on version comparison""" + try: + current_ver = version.parse(CURRENT_DSL_VERSION) + imported_ver = version.parse(imported_version) + except version.InvalidVersion: + return ImportStatus.FAILED + + # Compare major version and minor version + if current_ver.major != imported_ver.major or current_ver.minor != imported_ver.minor: + return ImportStatus.PENDING + + if current_ver.micro != imported_ver.micro: + return 
ImportStatus.COMPLETED_WITH_WARNINGS + + return ImportStatus.COMPLETED + + +class PendingData(BaseModel): + import_mode: str + yaml_content: str + name: str | None + description: str | None + icon_type: str | None + icon: str | None + icon_background: str | None + app_id: str | None + + +class AppDslService: + def __init__(self, session: Session): + self._session = session + + def import_app( + self, + *, + account: Account, + import_mode: str, + yaml_content: Optional[str] = None, + yaml_url: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + icon_type: Optional[str] = None, + icon: Optional[str] = None, + icon_background: Optional[str] = None, + app_id: Optional[str] = None, + ) -> Import: + """Import an app from YAML content or URL.""" + import_id = str(uuid.uuid4()) + + # Validate import mode + try: + mode = ImportMode(import_mode) + except ValueError: + raise ValueError(f"Invalid import_mode: {import_mode}") + + # Get YAML content + content = "" + if mode == ImportMode.YAML_URL: + if not yaml_url: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_url is required when import_mode is yaml-url", + ) + try: + response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) + response.raise_for_status() + content = response.content + + if len(content) > DSL_MAX_SIZE: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="File size exceeds the limit of 10MB", + ) + + if not content: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Empty content from url", + ) + + try: + content = content.decode("utf-8") + except UnicodeDecodeError as e: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error=f"Error decoding content: {e}", + ) + except Exception as e: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error=f"Error fetching YAML from URL: {str(e)}", + ) + elif mode == ImportMode.YAML_CONTENT: + if not yaml_content: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="yaml_content is required when import_mode is yaml-content", + ) + content = yaml_content + + # Process YAML content + try: + # Parse YAML to validate format + data = yaml.safe_load(content) + if not isinstance(data, dict): + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid YAML format: content must be a mapping", + ) + + # Validate and fix DSL version + if not data.get("version"): + data["version"] = "0.1.0" + if not data.get("kind") or data.get("kind") != "app": + data["kind"] = "app" + + imported_version = data.get("version", "0.1.0") + status = _check_version_compatibility(imported_version) + + # Extract app data + app_data = data.get("app") + if not app_data: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Missing app data in YAML content", + ) + + # If app_id is provided, check if it exists + app = None + if app_id: + stmt = select(App).where(App.id == app_id, App.tenant_id == account.current_tenant_id) + app = self._session.scalar(stmt) + + if not app: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="App not found", + ) + + if app.mode not in [AppMode.WORKFLOW.value, AppMode.ADVANCED_CHAT.value]: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Only workflow or advanced chat apps can be overwritten", + ) + + # If major version mismatch, store import info in Redis + if status == ImportStatus.PENDING: + pending_data = PendingData( + 
import_mode=import_mode, + yaml_content=content, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + app_id=app_id, + ) + redis_client.setex( + f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending_data.model_dump_json(), + ) + + return Import( + id=import_id, + status=status, + app_id=app_id, + imported_dsl_version=imported_version, + ) + + try: + dependencies = self.check_dependencies(account.current_tenant_id, data.get("dependencies", [])) + except Exception as e: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + if len(dependencies) > 0: + return Import( + id=import_id, + status=ImportStatus.PENDING, + app_id=app_id, + imported_dsl_version=imported_version, + dependencies=dependencies, + ) + + # Create or update app + app = self._create_or_update_app( + app=app, + data=data, + account=account, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + ) + + return Import( + id=import_id, + status=status, + app_id=app.id, + imported_dsl_version=imported_version, + ) + + except yaml.YAMLError as e: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error=f"Invalid YAML format: {str(e)}", + ) + + except Exception as e: + logger.exception("Failed to import app") + return Import( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def confirm_import(self, *, import_id: str, account: Account) -> Import: + """ + Confirm an import that requires confirmation + """ + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + pending_data = redis_client.get(redis_key) + + if not pending_data: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Import information expired or does not exist", + ) + + try: + if not isinstance(pending_data, str | bytes): + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid import information", + ) + pending_data = PendingData.model_validate_json(pending_data) + data = yaml.safe_load(pending_data.yaml_content) + + app = None + if pending_data.app_id: + stmt = select(App).where(App.id == pending_data.app_id, App.tenant_id == account.current_tenant_id) + app = self._session.scalar(stmt) + + # Create or update app + app = self._create_or_update_app( + app=app, + data=data, + account=account, + name=pending_data.name, + description=pending_data.description, + icon_type=pending_data.icon_type, + icon=pending_data.icon, + icon_background=pending_data.icon_background, + ) + + # Delete import info from Redis + redis_client.delete(redis_key) + + return Import( + id=import_id, + status=ImportStatus.COMPLETED, + app_id=app.id, + current_dsl_version=CURRENT_DSL_VERSION, + imported_dsl_version=data.get("version", "0.1.0"), + ) + + except Exception as e: + logger.exception("Error confirming import") + return Import( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def _create_or_update_app( + self, + *, + app: Optional[App], + data: dict, + account: Account, + name: Optional[str] = None, + description: Optional[str] = None, + icon_type: Optional[str] = None, + icon: Optional[str] = None, + icon_background: Optional[str] = None, + ) -> App: + """Create a new app or update an existing one.""" + app_data = data.get("app", {}) + app_mode = AppMode(app_data["mode"]) + + # Set icon type + icon_type_value = icon_type or app_data.get("icon_type") + if icon_type_value in ["emoji", "link"]: + icon_type = 
icon_type_value + else: + icon_type = "emoji" + icon = icon or str(app_data.get("icon", "")) + + if app: + # Update existing app + app.name = name or app_data.get("name", app.name) + app.description = description or app_data.get("description", app.description) + app.icon_type = icon_type + app.icon = icon + app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) + app.updated_by = account.id + else: + # Create new app + app = App() + app.id = str(uuid4()) + app.tenant_id = account.current_tenant_id + app.mode = app_mode.value + app.name = name or app_data.get("name", "") + app.description = description or app_data.get("description", "") + app.icon_type = icon_type + app.icon = icon + app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF") + app.enable_site = True + app.enable_api = True + app.use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) + app.created_by = account.id + app.updated_by = account.id + + self._session.add(app) + self._session.commit() + app_was_created.send(app, account=account) + + # Initialize app based on mode + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for workflow/advanced chat app") + + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] + + workflow_service = WorkflowService() + current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: + # Initialize model config + model_config = data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise ValueError("Missing model_config for chat/agent-chat/completion app") + # Initialize or update model config + if not app.app_model_config: + app_model_config = AppModelConfig().from_model_config_dict(model_config) + app_model_config.id = str(uuid4()) + app_model_config.app_id = app.id + app_model_config.created_by = account.id + app_model_config.updated_by = account.id + + app.app_model_config_id = app_model_config.id + + self._session.add(app_model_config) + app_model_config_was_updated.send(app, app_model_config=app_model_config) + else: + raise ValueError("Invalid app mode") + return app + + @classmethod + def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: + """ + Export app + :param app_model: App instance + :return: + """ + app_mode = AppMode.value_of(app_model.mode) + + export_data = { + "version": CURRENT_DSL_VERSION, + "kind": "app", + "app": { + "name": app_model.name, + "mode": app_model.mode, + "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, + "icon_background": "#FFEAD5" if 
app_model.icon_type == "image" else app_model.icon_background, + "description": app_model.description, + "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon, + }, + } + + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + cls._append_workflow_export_data( + export_data=export_data, app_model=app_model, include_secret=include_secret + ) + else: + cls._append_model_config_export_data(export_data, app_model) + + return yaml.dump(export_data, allow_unicode=True) + + @classmethod + def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: + """ + Append workflow export data + :param export_data: export data + :param app_model: App instance + """ + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model) + if not workflow: + raise ValueError("Missing draft workflow configuration, please check.") + + export_data["workflow"] = workflow.to_dict(include_secret=include_secret) + dependencies = cls._extract_dependencies_from_workflow(workflow) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=app_model.tenant_id, dependencies=dependencies + ) + ] + + @classmethod + def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: + """ + Append model config export data + :param export_data: export data + :param app_model: App instance + """ + app_model_config = app_model.app_model_config + if not app_model_config: + raise ValueError("Missing app configuration, please check.") + + export_data["model_config"] = app_model_config.to_dict() + dependencies = cls._extract_dependencies_from_model_config(app_model_config) + export_data["dependencies"] = [ + jsonable_encoder(d.model_dump()) + for d in DependenciesAnalysisService.generate_dependencies( + tenant_id=app_model.tenant_id, dependencies=dependencies + ) + ] + + @classmethod + def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]: + """ + Extract dependencies from workflow + :param workflow: Workflow instance + :return: dependencies list format like ["langgenius/google"] + """ + graph = workflow.graph_dict + dependencies = [] + for node in graph.get("nodes", []): + try: + typ = node.get("data", {}).get("type") + match typ: + case NodeType.TOOL.value: + tool_entity = ToolNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), + ) + case NodeType.LLM.value: + llm_entity = LLMNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), + ) + case NodeType.QUESTION_CLASSIFIER.value: + question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + question_classifier_entity.model.provider + ), + ) + case NodeType.PARAMETER_EXTRACTOR.value: + parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + parameter_extractor_entity.model.provider + ), + ) + case NodeType.KNOWLEDGE_RETRIEVAL.value: + knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + if knowledge_retrieval_entity.retrieval_mode == "multiple": + if knowledge_retrieval_entity.multiple_retrieval_config: + if ( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode + == 
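# Assembled end to end, export_dsl() above emits a YAML document shaped like
# this sketch (values are illustrative; "workflow" is replaced by
# "model_config" for chat/agent-chat/completion apps, and "dependencies"
# holds dumped PluginDependency mappings derived from identifiers of the form
# "org/plugin[:version@checksum]", e.g. "langgenius/google:1.0.0@abcdef1234567890"):
import yaml

export_data = {
    "version": "0.1.3",  # CURRENT_DSL_VERSION at export time (illustrative)
    "kind": "app",
    "app": {
        "name": "Demo App",
        "mode": "advanced-chat",
        "icon": "🤖",  # image icons are masked with this default
        "icon_background": "#FFEAD5",
        "description": "",
        "use_icon_as_answer_icon": False,
    },
    "workflow": {"graph": {}, "features": {}},
    "dependencies": [],
}
print(yaml.dump(export_data, allow_unicode=True))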
"reranking_model" + ): + if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider + ), + ) + elif ( + knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode + == "weighted_score" + ): + if knowledge_retrieval_entity.multiple_retrieval_config.weights: + vector_setting = ( + knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting + ) + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + vector_setting.embedding_provider_name + ), + ) + elif knowledge_retrieval_entity.retrieval_mode == "single": + model_config = knowledge_retrieval_entity.single_retrieval_config + if model_config: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + model_config.model.provider + ), + ) + case _: + # TODO: Handle default case or unknown node types + pass + except Exception as e: + logger.exception("Error extracting node dependency", exc_info=e) + + return dependencies + + @classmethod + def _extract_dependencies_from_model_config(cls, model_config: AppModelConfig) -> list[str]: + """ + Extract dependencies from model config + :param model_config: AppModelConfig instance + :return: dependencies list format like ["langgenius/google:1.0.0@abcdef1234567890"] + """ + dependencies = [] + + try: + # completion model + model_dict = model_config.model_dict + if model_dict: + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", "")) + ) + + # reranking model + dataset_configs = model_config.dataset_configs_dict + if dataset_configs: + for dataset_config in dataset_configs: + if dataset_config.get("reranking_model"): + dependencies.append( + DependenciesAnalysisService.analyze_model_provider_dependency( + dataset_config.get("reranking_model", {}) + .get("reranking_provider_name", {}) + .get("provider") + ) + ) + + # tools + agent_configs = model_config.agent_mode_dict + if agent_configs: + for agent_config in agent_configs: + if agent_config.get("tools"): + for tool in agent_config.get("tools", []): + dependencies.append( + DependenciesAnalysisService.analyze_tool_dependency(tool.get("provider_id")) + ) + except Exception as e: + logger.exception("Error extracting model config dependency", exc_info=e) + + return dependencies + + @classmethod + def check_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]: + """ + Returns the leaked dependencies in current workspace + """ + dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + if not dependencies: + return [] + + return DependenciesAnalysisService.check_dependencies(tenant_id=tenant_id, dependencies=dependencies) diff --git a/api/services/app_dsl_service/__init__.py b/api/services/app_dsl_service/__init__.py deleted file mode 100644 index 9fc988ffb3..0000000000 --- a/api/services/app_dsl_service/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .service import AppDslService - -__all__ = ["AppDslService"] diff --git a/api/services/app_dsl_service/exc.py b/api/services/app_dsl_service/exc.py deleted file mode 100644 index 6da4b1938f..0000000000 --- a/api/services/app_dsl_service/exc.py +++ /dev/null @@ -1,34 +0,0 @@ -class DSLVersionNotSupportedError(ValueError): - """Raised when the imported DSL version is not supported by the current Dify version.""" - - -class 
InvalidYAMLFormatError(ValueError):
-    """Raised when the provided YAML format is invalid."""
-
-
-class MissingAppDataError(ValueError):
-    """Raised when the app data is missing in the provided DSL."""
-
-
-class InvalidAppModeError(ValueError):
-    """Raised when the app mode is invalid."""
-
-
-class MissingWorkflowDataError(ValueError):
-    """Raised when the workflow data is missing in the provided DSL."""
-
-
-class MissingModelConfigError(ValueError):
-    """Raised when the model config data is missing in the provided DSL."""
-
-
-class FileSizeLimitExceededError(ValueError):
-    """Raised when the file size exceeds the allowed limit."""
-
-
-class EmptyContentError(ValueError):
-    """Raised when the content fetched from the URL is empty."""
-
-
-class ContentDecodingError(ValueError):
-    """Raised when there is an error decoding the content."""
diff --git a/api/services/app_dsl_service/service.py b/api/services/app_dsl_service/service.py
deleted file mode 100644
index 5fed40f2f5..0000000000
--- a/api/services/app_dsl_service/service.py
+++ /dev/null
@@ -1,666 +0,0 @@
-import logging
-from collections.abc import Mapping
-from typing import Any
-
-import yaml
-from packaging import version
-
-from core.helper import ssrf_proxy
-from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin import PluginDependency
-from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
-from core.workflow.nodes.llm.entities import LLMNodeData
-from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
-from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
-from core.workflow.nodes.tool.entities import ToolNodeData
-from events.app_event import app_model_config_was_updated, app_was_created
-from extensions.ext_database import db
-from factories import variable_factory
-from models.account import Account
-from models.model import App, AppMode, AppModelConfig
-from models.workflow import Workflow
-from services.plugin.dependencies_analysis import DependenciesAnalysisService
-from services.workflow_service import WorkflowService
-
-from .exc import (
-    ContentDecodingError,
-    EmptyContentError,
-    FileSizeLimitExceededError,
-    InvalidAppModeError,
-    InvalidYAMLFormatError,
-    MissingAppDataError,
-    MissingModelConfigError,
-    MissingWorkflowDataError,
-)
-
-logger = logging.getLogger(__name__)
-
-current_dsl_version = "0.1.3"
-dsl_max_size = 10 * 1024 * 1024  # 10MB
-
-
-class AppDslService:
-    @classmethod
-    def import_and_create_new_app_from_url(cls, tenant_id: str, url: str, args: dict, account: Account) -> App:
-        """
-        Import app dsl from url and create new app
-        :param tenant_id: tenant id
-        :param url: import url
-        :param args: request args
-        :param account: Account instance
-        """
-        response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10))
-        response.raise_for_status()
-        content = response.content
-
-        if len(content) > dsl_max_size:
-            raise FileSizeLimitExceededError("File size exceeds the limit of 10MB")
-
-        if not content:
-            raise EmptyContentError("Empty content from url")
-
-        try:
-            data = content.decode("utf-8")
-        except UnicodeDecodeError as e:
-            raise ContentDecodingError(f"Error decoding content: {e}")
-
-        return cls.import_and_create_new_app(tenant_id, data, args, account)
-
-    @classmethod
-    def check_dependencies_from_url(cls, tenant_id: str, url: str, account: Account) -> list[PluginDependency]:
-        """
- 
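# The removed URL import path above pairs a hard size cap with explicit
# decode handling; the same defensive shape in isolation, with plain httpx
# standing in for the project's ssrf_proxy wrapper (an assumption made for a
# self-contained sketch):
import httpx

DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB, mirroring dsl_max_size above

def fetch_dsl(url: str) -> str:
    response = httpx.get(url.strip(), follow_redirects=True, timeout=10)
    response.raise_for_status()
    content = response.content
    if len(content) > DSL_MAX_SIZE:
        raise ValueError("File size exceeds the limit of 10MB")
    if not content:
        raise ValueError("Empty content from url")
    try:
        return content.decode("utf-8")
    except UnicodeDecodeError as e:
        raise ValueError(f"Error decoding content: {e}") from e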
Check dependencies from url - """ - response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10)) - response.raise_for_status() - content = response.content - - if len(content) > dsl_max_size: - raise FileSizeLimitExceededError("File size exceeds the limit of 10MB") - - try: - data = content.decode("utf-8") - except UnicodeDecodeError as e: - raise ContentDecodingError(f"Error decoding content: {e}") - - return cls.check_dependencies(tenant_id, data, account) - - @classmethod - def check_dependencies(cls, tenant_id: str, data: str, account: Account) -> list[PluginDependency]: - """ - Returns the leaked dependencies in current workspace - """ - try: - import_data = yaml.safe_load(data) or {} - except yaml.YAMLError: - raise InvalidYAMLFormatError("Invalid YAML format in data argument.") - - dependencies = [PluginDependency(**dep) for dep in import_data.get("dependencies", [])] - if not dependencies: - return [] - - return DependenciesAnalysisService.check_dependencies(tenant_id=tenant_id, dependencies=dependencies) - - @classmethod - def import_and_create_new_app(cls, tenant_id: str, data: str, args: dict, account: Account) -> App: - """ - Import app dsl and create new app - :param tenant_id: tenant id - :param data: import data - :param args: request args - :param account: Account instance - """ - try: - import_data = yaml.safe_load(data) - except yaml.YAMLError: - raise InvalidYAMLFormatError("Invalid YAML format in data argument.") - - # check or repair dsl version - import_data = _check_or_fix_dsl(import_data) - - app_data = import_data.get("app") - if not app_data: - raise MissingAppDataError("Missing app in data argument") - - # get app basic info - name = args.get("name") or app_data.get("name") - description = args.get("description") or app_data.get("description", "") - icon_type = args.get("icon_type") or app_data.get("icon_type") - icon = args.get("icon") or app_data.get("icon") - icon_background = args.get("icon_background") or app_data.get("icon_background") - use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) - - # import dsl and create app - app_mode = AppMode.value_of(app_data.get("mode")) - - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - workflow_data = import_data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise MissingWorkflowDataError( - "Missing workflow in data argument when app mode is advanced-chat or workflow" - ) - - app = cls._import_and_create_new_workflow_based_app( - tenant_id=tenant_id, - app_mode=app_mode, - workflow_data=workflow_data, - account=account, - name=name, - description=description, - icon_type=icon_type, - icon=icon, - icon_background=icon_background, - use_icon_as_answer_icon=use_icon_as_answer_icon, - ) - elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: - model_config = import_data.get("model_config") - if not model_config or not isinstance(model_config, dict): - raise MissingModelConfigError( - "Missing model_config in data argument when app mode is chat, agent-chat or completion" - ) - - app = cls._import_and_create_new_model_config_based_app( - tenant_id=tenant_id, - app_mode=app_mode, - model_config_data=model_config, - account=account, - name=name, - description=description, - icon_type=icon_type, - icon=icon, - icon_background=icon_background, - use_icon_as_answer_icon=use_icon_as_answer_icon, - ) - else: - raise InvalidAppModeError("Invalid app mode") - - return app - - @classmethod - def import_and_overwrite_workflow(cls, 
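# The "request args win, DSL falls back" resolution above relies on `or`, so
# falsy args values (empty string, None) silently defer to the DSL; a tiny
# illustration of that edge:
app_data = {"name": "DSL name", "description": "From the DSL"}
args = {"name": "", "description": None}

name = args.get("name") or app_data.get("name")
description = args.get("description") or app_data.get("description", "")
assert (name, description) == ("DSL name", "From the DSL")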
app_model: App, data: str, account: Account) -> Workflow: - """ - Import app dsl and overwrite workflow - :param app_model: App instance - :param data: import data - :param account: Account instance - """ - try: - import_data = yaml.safe_load(data) - except yaml.YAMLError: - raise InvalidYAMLFormatError("Invalid YAML format in data argument.") - - # check or repair dsl version - import_data = _check_or_fix_dsl(import_data) - - app_data = import_data.get("app") - if not app_data: - raise MissingAppDataError("Missing app in data argument") - - # import dsl and overwrite app - app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - raise InvalidAppModeError("Only support import workflow in advanced-chat or workflow app.") - - if app_data.get("mode") != app_model.mode: - raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") - - workflow_data = import_data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise MissingWorkflowDataError( - "Missing workflow in data argument when app mode is advanced-chat or workflow" - ) - - return cls._import_and_overwrite_workflow_based_app( - app_model=app_model, - workflow_data=workflow_data, - account=account, - ) - - @classmethod - def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: - """ - Export app - :param app_model: App instance - :return: - """ - app_mode = AppMode.value_of(app_model.mode) - - export_data = { - "version": current_dsl_version, - "kind": "app", - "app": { - "name": app_model.name, - "mode": app_model.mode, - "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, - "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, - "description": app_model.description, - "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon, - }, - } - - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - cls._append_workflow_export_data( - export_data=export_data, app_model=app_model, include_secret=include_secret - ) - else: - cls._append_model_config_export_data(export_data, app_model) - - return yaml.dump(export_data, allow_unicode=True) - - @classmethod - def _import_and_create_new_workflow_based_app( - cls, - tenant_id: str, - app_mode: AppMode, - workflow_data: Mapping[str, Any], - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str, - use_icon_as_answer_icon: bool, - ) -> App: - """ - Import app dsl and create new workflow based app - - :param tenant_id: tenant id - :param app_mode: app mode - :param workflow_data: workflow data - :param account: Account instance - :param name: app name - :param description: app description - :param icon_type: app icon type, "emoji" or "image" - :param icon: app icon - :param icon_background: app icon background - :param use_icon_as_answer_icon: use app icon as answer icon - """ - if not workflow_data: - raise MissingWorkflowDataError( - "Missing workflow in data argument when app mode is advanced-chat or workflow" - ) - - app = cls._create_app( - tenant_id=tenant_id, - app_mode=app_mode, - account=account, - name=name, - description=description, - icon_type=icon_type, - icon=icon, - icon_background=icon_background, - use_icon_as_answer_icon=use_icon_as_answer_icon, - ) - - # init draft workflow - environment_variables_list = workflow_data.get("environment_variables") or [] - environment_variables = [ - 
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list - ] - conversation_variables_list = workflow_data.get("conversation_variables") or [] - conversation_variables = [ - variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list - ] - workflow_service = WorkflowService() - draft_workflow = workflow_service.sync_draft_workflow( - app_model=app, - graph=workflow_data.get("graph", {}), - features=workflow_data.get("features", {}), - unique_hash=None, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - ) - workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow) - - return app - - @classmethod - def _import_and_overwrite_workflow_based_app( - cls, app_model: App, workflow_data: Mapping[str, Any], account: Account - ) -> Workflow: - """ - Import app dsl and overwrite workflow based app - - :param app_model: App instance - :param workflow_data: workflow data - :param account: Account instance - """ - if not workflow_data: - raise MissingWorkflowDataError( - "Missing workflow in data argument when app mode is advanced-chat or workflow" - ) - - # fetch draft workflow by app_model - workflow_service = WorkflowService() - current_draft_workflow = workflow_service.get_draft_workflow(app_model=app_model) - if current_draft_workflow: - unique_hash = current_draft_workflow.unique_hash - else: - unique_hash = None - - # sync draft workflow - environment_variables_list = workflow_data.get("environment_variables") or [] - environment_variables = [ - variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list - ] - conversation_variables_list = workflow_data.get("conversation_variables") or [] - conversation_variables = [ - variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list - ] - draft_workflow = workflow_service.sync_draft_workflow( - app_model=app_model, - graph=workflow_data.get("graph", {}), - features=workflow_data.get("features", {}), - unique_hash=unique_hash, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - ) - - return draft_workflow - - @classmethod - def _import_and_create_new_model_config_based_app( - cls, - tenant_id: str, - app_mode: AppMode, - model_config_data: Mapping[str, Any], - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str, - use_icon_as_answer_icon: bool, - ) -> App: - """ - Import app dsl and create new model config based app - - :param tenant_id: tenant id - :param app_mode: app mode - :param model_config_data: model config data - :param account: Account instance - :param name: app name - :param description: app description - :param icon: app icon - :param icon_background: app icon background - """ - if not model_config_data: - raise MissingModelConfigError( - "Missing model_config in data argument when app mode is chat, agent-chat or completion" - ) - - app = cls._create_app( - tenant_id=tenant_id, - app_mode=app_mode, - account=account, - name=name, - description=description, - icon_type=icon_type, - icon=icon, - icon_background=icon_background, - use_icon_as_answer_icon=use_icon_as_answer_icon, - ) - - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(model_config_data) - app_model_config.app_id = app.id - app_model_config.created_by = account.id - app_model_config.updated_by = 
account.id - - db.session.add(app_model_config) - db.session.commit() - - app.app_model_config_id = app_model_config.id - - app_model_config_was_updated.send(app, app_model_config=app_model_config) - - return app - - @classmethod - def _create_app( - cls, - tenant_id: str, - app_mode: AppMode, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str, - use_icon_as_answer_icon: bool, - ) -> App: - """ - Create new app - - :param tenant_id: tenant id - :param app_mode: app mode - :param account: Account instance - :param name: app name - :param description: app description - :param icon_type: app icon type, "emoji" or "image" - :param icon: app icon - :param icon_background: app icon background - :param use_icon_as_answer_icon: use app icon as answer icon - """ - app = App( - tenant_id=tenant_id, - mode=app_mode.value, - name=name, - description=description, - icon_type=icon_type, - icon=icon, - icon_background=icon_background, - enable_site=True, - enable_api=True, - use_icon_as_answer_icon=use_icon_as_answer_icon, - created_by=account.id, - updated_by=account.id, - ) - - db.session.add(app) - db.session.commit() - - app_was_created.send(app, account=account) - - return app - - @classmethod - def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: - """ - Append workflow export data - :param export_data: export data - :param app_model: App instance - """ - workflow_service = WorkflowService() - workflow = workflow_service.get_draft_workflow(app_model) - if not workflow: - raise ValueError("Missing draft workflow configuration, please check.") - - export_data["workflow"] = workflow.to_dict(include_secret=include_secret) - dependencies = cls._extract_dependencies_from_workflow(workflow) - export_data["dependencies"] = [ - jsonable_encoder(d.model_dump()) - for d in DependenciesAnalysisService.generate_dependencies( - tenant_id=app_model.tenant_id, dependencies=dependencies - ) - ] - - @classmethod - def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: - """ - Append model config export data - :param export_data: export data - :param app_model: App instance - """ - app_model_config = app_model.app_model_config - if not app_model_config: - raise ValueError("Missing app configuration, please check.") - - export_data["model_config"] = app_model_config.to_dict() - dependencies = cls._extract_dependencies_from_model_config(app_model_config) - export_data["dependencies"] = [ - jsonable_encoder(d.model_dump()) - for d in DependenciesAnalysisService.generate_dependencies( - tenant_id=app_model.tenant_id, dependencies=dependencies - ) - ] - - @classmethod - def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]: - """ - Extract dependencies from workflow - :param workflow: Workflow instance - :return: dependencies list format like ["langgenius/google"] - """ - graph = workflow.graph_dict - dependencies = [] - for node in graph.get("nodes", []): - try: - typ = node.get("data", {}).get("type") - match typ: - case NodeType.TOOL.value: - tool_entity = ToolNodeData(**node["data"]) - dependencies.append( - DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), - ) - case NodeType.LLM.value: - llm_entity = LLMNodeData(**node["data"]) - dependencies.append( - DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), - ) - case NodeType.QUESTION_CLASSIFIER.value: - question_classifier_entity = 
QuestionClassifierNodeData(**node["data"]) - dependencies.append( - DependenciesAnalysisService.analyze_model_provider_dependency( - question_classifier_entity.model.provider - ), - ) - case NodeType.PARAMETER_EXTRACTOR.value: - parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) - dependencies.append( - DependenciesAnalysisService.analyze_model_provider_dependency( - parameter_extractor_entity.model.provider - ), - ) - case NodeType.KNOWLEDGE_RETRIEVAL.value: - knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) - if knowledge_retrieval_entity.retrieval_mode == "multiple": - if knowledge_retrieval_entity.multiple_retrieval_config: - if ( - knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode - == "reranking_model" - ): - if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model: - dependencies.append( - DependenciesAnalysisService.analyze_model_provider_dependency( - knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider - ), - ) - elif ( - knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode - == "weighted_score" - ): - if knowledge_retrieval_entity.multiple_retrieval_config.weights: - vector_setting = ( - knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting - ) - dependencies.append( - DependenciesAnalysisService.analyze_model_provider_dependency( - vector_setting.embedding_provider_name - ), - ) - elif knowledge_retrieval_entity.retrieval_mode == "single": - model_config = knowledge_retrieval_entity.single_retrieval_config - if model_config: - dependencies.append( - DependenciesAnalysisService.analyze_model_provider_dependency( - model_config.model.provider - ), - ) - case _: - # Handle default case or unknown node types - pass - except Exception as e: - logger.exception("Error extracting node dependency", exc_info=e) - - return dependencies - - @classmethod - def _extract_dependencies_from_model_config(cls, model_config: AppModelConfig) -> list[str]: - """ - Extract dependencies from model config - :param model_config: AppModelConfig instance - :return: dependencies list format like ["langgenius/google:1.0.0@abcdef1234567890"] - """ - dependencies = [] - - try: - # completion model - model_dict = model_config.model_dict - if model_dict: - dependencies.append( - DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider")) - ) - - # reranking model - dataset_configs = model_config.dataset_configs_dict - if dataset_configs: - for dataset_config in dataset_configs: - if dataset_config.get("reranking_model"): - dependencies.append( - DependenciesAnalysisService.analyze_model_provider_dependency( - dataset_config.get("reranking_model", {}) - .get("reranking_provider_name", {}) - .get("provider") - ) - ) - - # tools - agent_configs = model_config.agent_mode_dict - if agent_configs: - for agent_config in agent_configs: - if agent_config.get("tools"): - for tool in agent_config.get("tools", []): - dependencies.append( - DependenciesAnalysisService.analyze_tool_dependency(tool.get("provider_id")) - ) - except Exception as e: - logger.exception("Error extracting model config dependency", exc_info=e) - - return dependencies - - -def _check_or_fix_dsl(import_data: dict[str, Any]) -> Mapping[str, Any]: - """ - Check or fix dsl - - :param import_data: import data - :raises DSLVersionNotSupportedError: if the imported DSL version is newer than the current version - """ - if not import_data.get("version"): - import_data["version"] = "0.1.0" - - 
if not import_data.get("kind") or import_data.get("kind") != "app": - import_data["kind"] = "app" - - imported_version = import_data.get("version") - if imported_version != current_dsl_version: - if imported_version and version.parse(imported_version) > version.parse(current_dsl_version): - errmsg = ( - f"The imported DSL version {imported_version} is newer than " - f"the current supported version {current_dsl_version}. " - f"Please upgrade your Dify instance to import this configuration." - ) - logger.warning(errmsg) - # raise DSLVersionNotSupportedError(errmsg) - else: - logger.warning( - f"DSL version {imported_version} is older than " - f"the current version {current_dsl_version}. " - f"This may cause compatibility issues." - ) - - return import_data diff --git a/api/services/app_service.py b/api/services/app_service.py index 154360bb0b..8d8ba735ec 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import cast from flask_login import current_user @@ -155,10 +155,7 @@ class AppService: """ # get original app model config if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: - model_config: AppModelConfig | None = app.app_model_config - if not model_config: - return app - + model_config = app.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input for tool in agent_mode.get("tools") or []: @@ -226,7 +223,7 @@ class AppService: app.icon_background = args.get("icon_background") app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() if app.max_active_requests is not None: @@ -243,7 +240,7 @@ class AppService: """ app.name = name app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return app @@ -259,7 +256,7 @@ class AppService: app.icon = icon app.icon_background = icon_background app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return app @@ -276,7 +273,7 @@ class AppService: app.enable_site = enable_site app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return app @@ -293,7 +290,7 @@ class AppService: app.enable_api = enable_api app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return app diff --git a/api/services/auth/auth_type.py b/api/services/auth/auth_type.py index 2d6e901447..2e1946841f 100644 --- a/api/services/auth/auth_type.py +++ b/api/services/auth/auth_type.py @@ -1,6 +1,6 @@ -from enum import Enum +from enum import StrEnum -class AuthType(str, Enum): +class AuthType(StrEnum): FIRECRAWL = "firecrawl" JINA = "jinareader" diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index f9e41988c0..f3e76d3300 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,4 +1,4 @@ -from datetime import 
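# _check_or_fix_dsl above only logs a warning on version mismatch instead of
# raising; the comparison itself uses packaging.version, which orders
# multi-digit and pre-release components correctly where plain string
# comparison would not:
from packaging import version

current = version.parse("0.1.3")
assert version.parse("0.1.10") > current    # lexicographic compare would invert this
assert version.parse("0.2.0rc1") > current  # pre-releases of newer versions still rank above
assert version.parse("0.1.0") < current     # older DSLs import with a compatibility warning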
datetime, timezone +from datetime import UTC, datetime from typing import Optional, Union from sqlalchemy import asc, desc, or_ @@ -104,7 +104,7 @@ class ConversationService: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name - conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return conversation diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 806dbdf8c5..d38729f31e 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -600,7 +600,7 @@ class DocumentService: # update document to be paused document.is_paused = True document.paused_by = current_user.id - document.paused_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -1072,7 +1072,7 @@ class DocumentService: document.parsing_completed_at = None document.cleaning_completed_at = None document.splitting_completed_at = None - document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.created_from = created_from document.doc_form = document_data["doc_form"] db.session.add(document) @@ -1409,8 +1409,8 @@ class SegmentService: word_count=len(content), tokens=tokens, status="completed", - indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), created_by=current_user.id, ) if document.doc_form == "qa_model": @@ -1429,7 +1429,7 @@ class SegmentService: except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment_document.status = "error" segment_document.error = str(e) db.session.commit() @@ -1481,8 +1481,8 @@ class SegmentService: word_count=len(content), tokens=tokens, status="completed", - indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), created_by=current_user.id, ) if document.doc_form == "qa_model": @@ -1508,7 +1508,7 @@ class SegmentService: logging.exception("create segment index failed") for segment_document in segment_data_list: segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment_document.status = "error" segment_document.error = str(e) db.session.commit() @@ -1526,7 +1526,7 @@ class SegmentService: if segment.enabled != action: if not action: segment.enabled = action - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.disabled_by = 
current_user.id db.session.add(segment) db.session.commit() @@ -1585,10 +1585,10 @@ class SegmentService: segment.word_count = len(content) segment.tokens = tokens segment.status = "completed" - segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.updated_by = current_user.id - segment.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.enabled = True segment.disabled_at = None segment.disabled_by = None @@ -1608,7 +1608,7 @@ class SegmentService: except Exception as e: logging.exception("update segment index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 98e5d9face..7e3cd87f1e 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, Optional, Union import httpx @@ -99,7 +99,7 @@ class ExternalDatasetService: external_knowledge_api.description = args.get("description", "") external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False) external_knowledge_api.updated_by = user_id - external_knowledge_api.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return external_knowledge_api diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 6fe2f97d82..797f5f9686 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum from pydantic import BaseModel, ConfigDict @@ -22,7 +22,7 @@ class LimitationModel(BaseModel): limit: int = 0 -class LicenseStatus(str, Enum): +class LicenseStatus(StrEnum): NONE = "none" INACTIVE = "inactive" ACTIVE = "active" diff --git a/api/services/file_service.py b/api/services/file_service.py index 976111502c..b12b95ca13 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -77,7 +77,7 @@ class FileService: mime_type=mimetype, created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), created_by=user.id, - created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=False, hash=hashlib.sha3_256(content).hexdigest(), source_url=source_url, @@ -123,10 +123,10 @@ class FileService: mime_type="text/plain", created_by=current_user.id, created_by_role=CreatedByRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + 
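# The datetime churn in these files is mechanical: datetime.UTC (Python 3.11+)
# is just an alias for datetime.timezone.utc, so every rewritten call site
# produces the same naive-UTC timestamp as before:
from datetime import UTC, datetime, timezone

assert UTC is timezone.utc
naive_utc = datetime.now(UTC).replace(tzinfo=None)
assert naive_utc.tzinfo is None  # stored naive, but the wall-clock value is UTC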
used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), ) db.session.add(upload_file) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 5c5c0c1607..1d9a736540 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -371,7 +371,7 @@ class ModelLoadBalancingService: load_balancing_config.name = name load_balancing_config.enabled = enabled - load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() self._clear_credentials_cache(tenant_id, config_id) diff --git a/api/services/recommend_app/recommend_app_type.py b/api/services/recommend_app/recommend_app_type.py index 7ea93b3f64..e60e435b3a 100644 --- a/api/services/recommend_app/recommend_app_type.py +++ b/api/services/recommend_app/recommend_app_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class RecommendAppType(str, Enum): +class RecommendAppType(StrEnum): REMOTE = "remote" BUILDIN = "builtin" DATABASE = "db" diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 8cb3f9fe6e..9fd40c412d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,7 +1,7 @@ import json import time from collections.abc import Callable, Generator, Sequence -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -118,7 +118,7 @@ class WorkflowService: workflow.graph = json.dumps(graph) workflow.features = json.dumps(features) workflow.updated_by = account.id - workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables @@ -151,7 +151,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=str(datetime.now(timezone.utc).replace(tzinfo=None)), + version=str(datetime.now(UTC).replace(tzinfo=None)), graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, @@ -312,18 +312,22 @@ class WorkflowService: workflow_node_execution.title = node_instance.node_data.title workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value - workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) - workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) if run_succeeded and node_run_result: # create workflow node execution - workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None - workflow_node_execution.process_data = ( - json.dumps(node_run_result.process_data) if node_run_result.process_data else None - ) - workflow_node_execution.outputs = ( - json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None + inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None + process_data = ( + 
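# The enum migrations in this changeset (AuthType, LicenseStatus,
# RecommendAppType) swap `class X(str, Enum)` for `class X(StrEnum)`. On
# Python 3.11+ the old mixin idiom formats as "ClassName.MEMBER" in f-strings,
# which StrEnum avoids; a minimal demonstration:
from enum import Enum, StrEnum

class OldStyle(str, Enum):
    REMOTE = "remote"

class NewStyle(StrEnum):
    REMOTE = "remote"

assert OldStyle.REMOTE == "remote" and NewStyle.REMOTE == "remote"
assert f"{NewStyle.REMOTE}" == "remote"
assert f"{OldStyle.REMOTE}" == "OldStyle.REMOTE"  # the 3.11+ formatting surprise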
WorkflowEntry.handle_special_values(node_run_result.process_data) + if node_run_result.process_data + else None ) + outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + + workflow_node_execution.inputs = json.dumps(inputs) + workflow_node_execution.process_data = json.dumps(process_data) + workflow_node_execution.outputs = json.dumps(outputs) workflow_node_execution.execution_metadata = ( json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None ) @@ -355,10 +359,10 @@ class WorkflowService: new_app = workflow_converter.convert_to_workflow( app_model=app_model, account=account, - name=args.get("name", ""), - icon_type=args.get("icon_type", ""), - icon=args.get("icon", ""), - icon_background=args.get("icon_background", ""), + name=args.get("name", "Default Name"), + icon_type=args.get("icon_type", "emoji"), + icon=args.get("icon", "🤖"), + icon_background=args.get("icon_background", "#FFEAD5"), ) return new_app diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index b50876cc79..09be661216 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -18,9 +18,9 @@ from models.dataset import DocumentSegment def add_document_to_index_task(dataset_document_id: str): """ Async Add document to index - :param document_id: + :param dataset_document_id: - Usage: add_document_to_index.delay(document_id) + Usage: add_document_to_index.delay(dataset_document_id) """ logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) start_at = time.perf_counter() @@ -74,7 +74,7 @@ def add_document_to_index_task(dataset_document_id: str): except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False - dataset_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) dataset_document.status = "error" dataset_document.error = str(e) db.session.commit() diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index e819bf3635..0bdcd0eccd 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -52,7 +52,7 @@ def enable_annotation_reply_task( annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id annotation_setting.updated_user_id = user_id - annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(annotation_setting) else: new_app_annotation_setting = AppAnnotationSetting( diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 5ee72c27fc..dcb7009e44 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -80,9 +80,9 @@ def batch_create_segment_to_index_task( word_count=len(content), tokens=tokens, created_by=user_id, - indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), status="completed", - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + 
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), ) if dataset_document.doc_form == "qa_model": segment_document.answer = segment["answer"] diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 26375743b6..315b01f157 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -38,7 +38,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment status to indexing update_params = { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -75,7 +75,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment to completed update_params = { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -87,7 +87,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] except Exception as e: logging.exception("create segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 6dd755ab03..1831691393 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -67,7 +67,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): # check the page is updated if last_edited_time != page_edited_time: document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() # delete all document segment and index diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index df1177d578..734dd2478a 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -50,7 +50,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(document) db.session.commit() return @@ -64,7 +64,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index cb38bc668d..1a52a6636b 100644 --- 
a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -30,7 +30,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): raise NotFound("Document not found") document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() # delete all document segment and index diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 1412ad9ec7..12639db939 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -71,7 +71,7 @@ def enable_segment_to_index_task(segment_id: str): except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index f6b3be8250..f6555cfdde 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator -from datetime import datetime, timezone +from datetime import UTC, datetime, timezone from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey @@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: - route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) parallel_id = graph.node_parallel_mapping.get(next_node_id) parallel_start_node_id = None @@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve ) route_node_state.status = RouteNodeState.Status.SUCCESS - route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None) yield NodeRunSucceededEvent( id=node_execution_id, node_id=next_node_id, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index def6c2a232..9a24d35a1f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,125 +1,431 @@ +from collections.abc import Sequence +from typing import Optional + import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.file import File, FileTransferMethod, FileType -from 
core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel +from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.end import EndStreamParam -from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, + VisionConfigOptions, +) from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom +from models.provider import ProviderType from models.workflow import WorkflowType +from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario -class TestLLMNode: - @pytest.fixture - def llm_node(self): - data = LLMNodeData( - title="Test LLM", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[], - memory=None, - context=ContextConfig(enabled=False), - vision=VisionConfig( - enabled=True, - configs=VisionConfigOptions( - variable_selector=["sys", "files"], - detail=ImagePromptMessageContent.DETAIL.HIGH, - ), - ), - ) - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - node = LLMNode( - id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), - ) - return node +class MockTokenBufferMemory: + def __init__(self, history_messages=None): + self.history_messages = history_messages or [] - def test_fetch_files_with_file_segment(self, llm_node): - file = File( + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> Sequence[PromptMessage]: + if message_limit is not None: + return self.history_messages[-message_limit * 2 :] + return self.history_messages + + +@pytest.fixture +def llm_node(): + data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[], + memory=None, + 
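# MockTokenBufferMemory above trims history in pairs: message_limit counts
# user/assistant exchanges, hence the `-message_limit * 2` slice. The slice
# in isolation:
history = ["u1", "a1", "u2", "a2", "u3", "a3"]  # alternating user/assistant turns
message_limit = 2
assert history[-message_limit * 2 :] == ["u2", "a2", "u3", "a3"]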
context=ContextConfig(enabled=False), + vision=VisionConfig( + enabled=True, + configs=VisionConfigOptions( + variable_selector=["sys", "files"], + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + +@pytest.fixture +def model_config(): + # Create actual provider and model type instances + model_provider_factory = ModelProviderFactory() + provider_instance = model_provider_factory.get_provider_instance("openai") + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + + # Create a ProviderModelBundle + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id="1", + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=None), + model_settings=[], + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance, + ) + + # Create and return a ModelConfigWithCredentialsEntity + return ModelConfigWithCredentialsEntity( + provider="openai", + model="gpt-3.5-turbo", + model_schema=AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ), + mode="chat", + credentials={}, + parameters={}, + provider_model_bundle=provider_model_bundle, + ) + + +def test_fetch_files_with_file_segment(llm_node): + file = File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ) + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [file] + + +def test_fetch_files_with_array_file_segment(llm_node): + files = [ + File( id="1", tenant_id="test", type=FileType.IMAGE, - filename="test.jpg", + filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + ), + File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="test2.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="2", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == files + + +def test_fetch_files_with_none_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_array_any_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", 
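# The restructuring above flattens a `class TestLLMNode` whose fixture was a
# method into module-level pytest fixtures and free test functions; the shape
# in miniature (names illustrative):
import pytest

@pytest.fixture
def node():
    return {"id": "1"}  # stand-in for the fully wired LLMNode

def test_node_id(node):  # injected by argument name; no test class required
    assert node["id"] == "1"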
"files"], ArrayAnySegment(value=[])) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_non_existent_variable(llm_node): + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): + prompt_template = [] + llm_node.node_data.prompt_template = prompt_template + + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, ) - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + ] - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [file] + fake_query = faker.sentence() - def test_fetch_files_with_array_file_segment(self, llm_node): - files = [ - File( - id="1", - tenant_id="test", - type=FileType.IMAGE, - filename="test1.jpg", - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="1", - ), - File( - id="2", - tenant_id="test", - type=FileType.IMAGE, - filename="test2.jpg", - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="2", - ), - ] - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=fake_query, + user_files=files, + context=None, + memory=None, + model_config=model_config, + prompt_template=prompt_template, + memory_config=None, + vision_enabled=False, + vision_detail=fake_vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == files + assert prompt_messages == [UserPromptMessage(content=fake_query)] - def test_fetch_files_with_none_segment(self, llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] +def test_fetch_prompt_messages__basic(faker, llm_node, model_config): + # Setup dify config + dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" + dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url" - def test_fetch_files_with_array_any_segment(self, llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + # Generate fake values for prompt template + fake_assistant_prompt = faker.sentence() + fake_query = faker.sentence() + fake_context = faker.sentence() + fake_window_size = faker.random_int(min=1, max=3) + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + # Setup mock memory with history messages + mock_history = [ + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + ] - def test_fetch_files_with_non_existent_variable(self, llm_node): - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + # Setup memory configuration + 
memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size), + query_prompt_template=None, + ) + + memory = MockTokenBufferMemory(history_messages=mock_history) + + # Test scenarios covering different file input combinations + test_scenarios = [ + LLMNodeTestScenario( + description="No files", + user_query=fake_query, + user_files=[], + features=[], + vision_enabled=False, + vision_detail=None, + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + expected_messages=[ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), + ] + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage(content=fake_query), + ], + ), + LLMNodeTestScenario( + description="User files", + user_query=fake_query, + user_files=[ + File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + ], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + expected_messages=[ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), + ] + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data=fake_query), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ), + ], + ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File", + user_query=fake_query, + user_files=[], + vision_enabled=False, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=[ + UserPromptMessage( + content=[ + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ), + ] + + mock_history[fake_window_size * -2 :] + + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), + ] + + for scenario in test_scenarios: + model_config.model_schema.features = scenario.features + + for k, v in scenario.file_variables.items(): + selector = k.split(".") + llm_node.graph_runtime_state.variable_pool.add(selector, v) + + # Call the method under test + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=scenario.user_query, + user_files=scenario.user_files, + 
context=fake_context,
+            memory=memory,
+            model_config=model_config,
+            prompt_template=scenario.prompt_template,
+            memory_config=memory_config,
+            vision_enabled=scenario.vision_enabled,
+            vision_detail=scenario.vision_detail,
+            variable_pool=llm_node.graph_runtime_state.variable_pool,
+            jinja2_variables=[],
+        )
+
+        # Verify the result
+        assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
+        assert (
+            prompt_messages == scenario.expected_messages
+        ), f"Message content mismatch in scenario: {scenario.description}"
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
new file mode 100644
index 0000000000..8e39445baf
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
@@ -0,0 +1,25 @@
+from collections.abc import Mapping, Sequence
+
+from pydantic import BaseModel, Field
+
+from core.file import File
+from core.model_runtime.entities.message_entities import PromptMessage
+from core.model_runtime.entities.model_entities import ModelFeature
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
+
+
+class LLMNodeTestScenario(BaseModel):
+    """Test scenario for LLM node testing."""
+
+    description: str = Field(..., description="Description of the test scenario")
+    user_query: str = Field(..., description="User query input")
+    user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
+    vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
+    vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
+    features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
+    window_size: int = Field(..., description="Window size for memory")
+    prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
+    file_variables: Mapping[str, File | Sequence[File]] = Field(
+        default_factory=dict, description="Mapping of file variables"
+    )
+    expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
diff --git a/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py b/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py
deleted file mode 100644
index 842e8268d1..0000000000
--- a/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import pytest
-from packaging import version
-
-from services.app_dsl_service import AppDslService
-from services.app_dsl_service.exc import DSLVersionNotSupportedError
-from services.app_dsl_service.service import _check_or_fix_dsl, current_dsl_version
-
-
-class TestAppDSLService:
-    @pytest.mark.skip(reason="Test skipped")
-    def test_check_or_fix_dsl_missing_version(self):
-        import_data = {}
-        result = _check_or_fix_dsl(import_data)
-        assert result["version"] == "0.1.0"
-        assert result["kind"] == "app"
-
-    @pytest.mark.skip(reason="Test skipped")
-    def test_check_or_fix_dsl_missing_kind(self):
-        import_data = {"version": "0.1.0"}
-        result = _check_or_fix_dsl(import_data)
-        assert result["kind"] == "app"
-
-    @pytest.mark.skip(reason="Test skipped")
-    def test_check_or_fix_dsl_older_version(self):
-        import_data = {"version": "0.0.9", "kind": "app"}
-        result = _check_or_fix_dsl(import_data)
-        assert result["version"] == "0.0.9"
-
-    @pytest.mark.skip(reason="Test skipped")
-    def
test_check_or_fix_dsl_current_version(self):
-        import_data = {"version": current_dsl_version, "kind": "app"}
-        result = _check_or_fix_dsl(import_data)
-        assert result["version"] == current_dsl_version
-
-    @pytest.mark.skip(reason="Test skipped")
-    def test_check_or_fix_dsl_newer_version(self):
-        current_version = version.parse(current_dsl_version)
-        newer_version = f"{current_version.major}.{current_version.minor + 1}.0"
-        import_data = {"version": newer_version, "kind": "app"}
-        with pytest.raises(DSLVersionNotSupportedError):
-            _check_or_fix_dsl(import_data)
-
-    @pytest.mark.skip(reason="Test skipped")
-    def test_check_or_fix_dsl_invalid_kind(self):
-        import_data = {"version": current_dsl_version, "kind": "invalid"}
-        result = _check_or_fix_dsl(import_data)
-        assert result["kind"] == "app"
diff --git a/docker/.env.example b/docker/.env.example
index d29c66535d..50dc56a5c9 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -574,9 +574,11 @@ UPLOAD_FILE_BATCH_LIMIT=5
 # `Unstructured` Unstructured.io file extraction scheme
 ETL_TYPE=dify
-# Unstructured API path, needs to be configured when ETL_TYPE is Unstructured.
+# Unstructured API URL and API key, required when ETL_TYPE is Unstructured,
+# or when using Unstructured for the document extractor node (e.g. for .pptx files).
 # For example: http://unstructured:8000/general/v0/general
 UNSTRUCTURED_API_URL=
+UNSTRUCTURED_API_KEY=

 # ------------------------------
 # Model Configuration
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index b47573bfc8..285f576b0e 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -177,7 +177,7 @@ x-shared-env: &shared-api-worker-env
  ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
  LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070}
  LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm}
-  LINDORM_PASSWORD: ${LINDORM_USERNAME:-lindorm }
+  LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm}
  KIBANA_PORT: ${KIBANA_PORT:-5601}
  # AnalyticDB configuration
  ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-}
@@ -222,6 +222,7 @@ x-shared-env: &shared-api-worker-env
  UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
  ETL_TYPE: ${ETL_TYPE:-dify}
  UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-}
+  UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
  PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
  CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
  MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64}
diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx
index 23f00d46d8..f30d3e4a0a 100644
--- a/web/app/components/app/configuration/config-vision/index.tsx
+++ b/web/app/components/app/configuration/config-vision/index.tsx
@@ -12,34 +12,46 @@ import ConfigContext from '@/context/debug-configuration'
 // import { Resolution } from '@/types/app'
 import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks'
 import Switch from '@/app/components/base/switch'
-import type { FileUpload } from '@/app/components/base/features/types'
+import { SupportUploadFileTypes } from '@/app/components/workflow/types'

 const ConfigVision: FC = () => {
   const { t } = useTranslation()
-  const { isShowVisionConfig } = useContext(ConfigContext)
+  const { isShowVisionConfig, isAllowVideoUpload } = useContext(ConfigContext)
   const file = useFeatures(s => s.features.file)
   const featuresStore = useFeaturesStore()

-  const handleChange = useCallback((data: FileUpload) => {
+  const
isImageEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.image) ?? false + + const handleChange = useCallback((value: boolean) => { const { features, setFeatures, } = featuresStore!.getState() const newFeatures = produce(features, (draft) => { - draft.file = { - ...draft.file, - enabled: data.enabled, - image: { - enabled: data.enabled, - detail: data.image?.detail, - transfer_methods: data.image?.transfer_methods, - number_limits: data.image?.number_limits, - }, + if (value) { + draft.file!.allowed_file_types = Array.from(new Set([ + ...(draft.file?.allowed_file_types || []), + SupportUploadFileTypes.image, + ...(isAllowVideoUpload ? [SupportUploadFileTypes.video] : []), + ])) + } + else { + draft.file!.allowed_file_types = draft.file!.allowed_file_types?.filter( + type => type !== SupportUploadFileTypes.image && (isAllowVideoUpload ? type !== SupportUploadFileTypes.video : true), + ) + } + + if (draft.file) { + draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0 + draft.file.image = { + ...(draft.file.image || {}), + enabled: value, + } } }) setFeatures(newFeatures) - }, [featuresStore]) + }, [featuresStore, isAllowVideoUpload]) if (!isShowVisionConfig) return null @@ -89,11 +101,8 @@ const ConfigVision: FC = () => {
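Note: both the reworked ConfigVision handler here and the new ConfigDocument component below drive their switch off `allowed_file_types` rather than a standalone boolean. A minimal sketch of that toggle pattern, assuming an immer-style update (the `FileFeature` shape and `toggleFileType` name are illustrative, not part of this change):

```ts
import { produce } from 'immer'

type FileFeature = {
  enabled?: boolean
  allowed_file_types?: string[]
}

// Add or remove one file type, then derive `enabled` from what remains.
function toggleFileType(feature: FileFeature, fileType: string, on: boolean): FileFeature {
  return produce(feature, (draft) => {
    if (on) {
      // De-duplicate so repeated toggles don't accumulate entries
      draft.allowed_file_types = Array.from(new Set([...(draft.allowed_file_types || []), fileType]))
    }
    else {
      draft.allowed_file_types = draft.allowed_file_types?.filter(type => type !== fileType)
    }
    // The file feature stays enabled only while at least one type is allowed
    draft.enabled = (draft.allowed_file_types?.length ?? 0) > 0
  })
}
```

Deriving `enabled` from the remaining types is what keeps the vision and document switches from clobbering each other's state.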
handleChange({ - ...(file || {}), - enabled: value, - })} + defaultValue={isImageEnabled} + onChange={handleChange} size='md' /> diff --git a/web/app/components/app/configuration/config/config-document.tsx b/web/app/components/app/configuration/config/config-document.tsx new file mode 100644 index 0000000000..1ac6da0dd8 --- /dev/null +++ b/web/app/components/app/configuration/config/config-document.tsx @@ -0,0 +1,78 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import { useContext } from 'use-context-selector' + +import { Document } from '@/app/components/base/icons/src/vender/features' +import Tooltip from '@/app/components/base/tooltip' +import ConfigContext from '@/context/debug-configuration' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import Switch from '@/app/components/base/switch' + +const ConfigDocument: FC = () => { + const { t } = useTranslation() + const file = useFeatures(s => s.features.file) + const featuresStore = useFeaturesStore() + const { isShowDocumentConfig } = useContext(ConfigContext) + + const isDocumentEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.document) ?? false + + const handleChange = useCallback((value: boolean) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + + const newFeatures = produce(features, (draft) => { + if (value) { + draft.file!.allowed_file_types = Array.from(new Set([ + ...(draft.file?.allowed_file_types || []), + SupportUploadFileTypes.document, + ])) + } + else { + draft.file!.allowed_file_types = draft.file!.allowed_file_types?.filter( + type => type !== SupportUploadFileTypes.document, + ) + } + if (draft.file) + draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0 + }) + setFeatures(newFeatures) + }, [featuresStore]) + + if (!isShowDocumentConfig) + return null + + return ( +
+
+
+ +
+
+
+
{t('appDebug.feature.documentUpload.title')}
+ + {t('appDebug.feature.documentUpload.description')} +
+ } + /> +
+
+
+ +
+ + ) +} +export default React.memo(ConfigDocument) diff --git a/web/app/components/app/configuration/config/index.tsx b/web/app/components/app/configuration/config/index.tsx index 8687079931..39fdd502ef 100644 --- a/web/app/components/app/configuration/config/index.tsx +++ b/web/app/components/app/configuration/config/index.tsx @@ -7,6 +7,7 @@ import { useFormattingChangedDispatcher } from '../debug/hooks' import DatasetConfig from '../dataset-config' import HistoryPanel from '../config-prompt/conversation-history/history-panel' import ConfigVision from '../config-vision' +import ConfigDocument from './config-document' import AgentTools from './agent/agent-tools' import ConfigContext from '@/context/debug-configuration' import ConfigPrompt from '@/app/components/app/configuration/config-prompt' @@ -82,6 +83,8 @@ const Config: FC = () => { + + {/* Chat History */} {isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && ( { } const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision) - + const isShowDocumentConfig = !!currModel?.features?.includes(ModelFeatureEnum.document) + const isAllowVideoUpload = !!currModel?.features?.includes(ModelFeatureEnum.video) // *** web app features *** const featuresData: FeaturesData = useMemo(() => { return { @@ -472,7 +473,7 @@ const Configuration: FC = () => { transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], }, enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled), - allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image, SupportUploadFileTypes.video], + allowed_file_types: modelConfig.file_upload?.allowed_file_types || [], allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`), allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3, @@ -861,6 +862,8 @@ const Configuration: FC = () => { isShowVisionConfig, visionConfig, setVisionConfig: handleSetVisionConfig, + isAllowVideoUpload, + isShowDocumentConfig, rerankSettingModalOpen, setRerankSettingModalOpen, }} diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index e238ce0e91..ce06b113bc 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -12,9 +12,13 @@ import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import { ToastContext } from '@/app/components/base/toast' import { - importApp, - importAppFromUrl, + importDSL, + importDSLConfirm, } from '@/service/apps' +import { + DSLImportMode, + DSLImportStatus, +} from '@/models/app' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' @@ -43,6 +47,9 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS const [fileContent, setFileContent] = useState() const [currentTab, setCurrentTab] = useState(activeTab) const [dslUrlValue, setDslUrlValue] = useState(dslUrl) + const [showErrorModal, 
setShowErrorModal] = useState(false) + const [versions, setVersions] = useState<{ importedVersion: string; systemVersion: string }>() + const [importId, setImportId] = useState() const readFile = (file: File) => { const reader = new FileReader() @@ -66,6 +73,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS const isAppsFull = (enableBilling && plan.usage.buildApps >= plan.total.buildApps) const isCreatingRef = useRef(false) + const onCreate: MouseEventHandler = async () => { if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile) return @@ -75,25 +83,54 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS return isCreatingRef.current = true try { - let app + let response if (currentTab === CreateFromDSLModalTab.FROM_FILE) { - app = await importApp({ - data: fileContent || '', + response = await importDSL({ + mode: DSLImportMode.YAML_CONTENT, + yaml_content: fileContent || '', }) } if (currentTab === CreateFromDSLModalTab.FROM_URL) { - app = await importAppFromUrl({ - url: dslUrlValue || '', + response = await importDSL({ + mode: DSLImportMode.YAML_URL, + yaml_url: dslUrlValue || '', }) } - if (onSuccess) - onSuccess() - if (onClose) - onClose() - notify({ type: 'success', message: t('app.newApp.appCreated') }) - localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') - getRedirection(isCurrentWorkspaceEditor, app, push) + + if (!response) + return + + const { id, status, app_id, imported_dsl_version, current_dsl_version } = response + if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { + if (onSuccess) + onSuccess() + if (onClose) + onClose() + + notify({ + type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning', + message: t(status === DSLImportStatus.COMPLETED ? 'app.newApp.appCreated' : 'app.newApp.caution'), + children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('app.newApp.appCreateDSLWarning'), + }) + localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + getRedirection(isCurrentWorkspaceEditor, { id: app_id }, push) + } + else if (status === DSLImportStatus.PENDING) { + setVersions({ + importedVersion: imported_dsl_version ?? '', + systemVersion: current_dsl_version ?? 
'', + }) + if (onClose) + onClose() + setTimeout(() => { + setShowErrorModal(true) + }, 300) + setImportId(id) + } + else { + notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + } } catch (e) { notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) @@ -101,6 +138,38 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS isCreatingRef.current = false } + const onDSLConfirm: MouseEventHandler = async () => { + try { + if (!importId) + return + const response = await importDSLConfirm({ + import_id: importId, + }) + + const { status, app_id } = response + + if (status === DSLImportStatus.COMPLETED) { + if (onSuccess) + onSuccess() + if (onClose) + onClose() + + notify({ + type: 'success', + message: t('app.newApp.appCreated'), + }) + localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + getRedirection(isCurrentWorkspaceEditor, { id: app_id }, push) + } + else if (status === DSLImportStatus.FAILED) { + notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + } + } + catch (e) { + notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + } + } + const tabs = [ { key: CreateFromDSLModalTab.FROM_FILE, @@ -123,74 +192,96 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS }, [isAppsFull, currentTab, currentFile, dslUrlValue]) return ( - { }} - > -
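Note: `onCreate` and `onDSLConfirm` together form a two-step handshake: `importDSL` can return `PENDING` when the file's DSL version does not match the system's, and the app is only created after `importDSLConfirm` is called with the returned import id. A hedged sketch of that control flow (the enum string values are assumed here, not confirmed by this diff):

```ts
enum DSLImportStatus {
  COMPLETED = 'completed',
  COMPLETED_WITH_WARNINGS = 'completed-with-warnings',
  PENDING = 'pending',
  FAILED = 'failed',
}

type DSLImportResponse = {
  id: string
  status: DSLImportStatus
  app_id?: string
}

// Runs the import, pausing for an explicit user confirmation on a version mismatch.
async function runDSLImport(
  importDSL: () => Promise<DSLImportResponse>,
  confirmImport: (importId: string) => Promise<DSLImportResponse>,
  askUserToProceed: () => Promise<boolean>,
): Promise<DSLImportResponse | null> {
  const first = await importDSL()
  // COMPLETED, COMPLETED_WITH_WARNINGS and FAILED are all terminal states
  if (first.status !== DSLImportStatus.PENDING)
    return first
  // PENDING surfaces the imported vs. system DSL versions to the user first
  return (await askUserToProceed()) ? confirmImport(first.id) : null
}
```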
- {t('app.importFromDSL')} -
onClose()} - > - + <> + { }} + > +
+ {t('app.importFromDSL')} +
onClose()} + > + +
-
-
- { - tabs.map(tab => ( -
setCurrentTab(tab.key)} - > - {tab.label} - { - currentTab === tab.key && ( -
- ) - } -
- )) - } -
-
- { - currentTab === CreateFromDSLModalTab.FROM_FILE && ( - - ) - } - { - currentTab === CreateFromDSLModalTab.FROM_URL && ( -
-
DSL URL
- setDslUrlValue(e.target.value)} +
+ { + tabs.map(tab => ( +
setCurrentTab(tab.key)} + > + {tab.label} + { + currentTab === tab.key && ( +
+ ) + } +
+ )) + } +
+
+ { + currentTab === CreateFromDSLModalTab.FROM_FILE && ( + -
- ) - } -
- {isAppsFull && ( -
- + ) + } + { + currentTab === CreateFromDSLModalTab.FROM_URL && ( +
+
DSL URL
+ setDslUrlValue(e.target.value)} + /> +
+ ) + }
- )} -
- - -
- + {isAppsFull && ( +
+ +
+ )} +
+ + +
+ + setShowErrorModal(false)} + className='w-[480px]' + > +
+
{t('app.newApp.appCreateDSLErrorTitle')}
+
+
{t('app.newApp.appCreateDSLErrorPart1')}
+
{t('app.newApp.appCreateDSLErrorPart2')}
+
+
{t('app.newApp.appCreateDSLErrorPart3')}{versions?.importedVersion}
+
{t('app.newApp.appCreateDSLErrorPart4')}{versions?.systemVersion}
+
+
+
+ + +
+
+ ) } diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index fa5554f9cf..beb2b4b1a0 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -6,6 +6,7 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' +import { formatFileSize } from '@/utils/format' import cn from '@/utils/classnames' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' import { ToastContext } from '@/app/components/base/toast' @@ -58,8 +59,13 @@ const Uploader: FC = ({ updateFile(files[0]) } const selectHandle = () => { - if (fileUploader.current) + const originalFile = file + if (fileUploader.current) { + fileUploader.current.value = '' fileUploader.current.click() + // If no file is selected, restore the original file + fileUploader.current.oncancel = () => updateFile(originalFile) + } } const removeFile = () => { if (fileUploader.current) @@ -96,7 +102,7 @@ const Uploader: FC = ({ />
{!file && ( -
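Note: `selectHandle` above clears the input's `value` before opening the picker, so choosing the same file again still fires a `change` event, and it uses the input's `cancel` event to restore the previous selection when the dialog is dismissed. A standalone sketch, assuming a browser environment and a hypothetical `onPicked` callback:

```ts
function openFilePicker(
  input: HTMLInputElement,
  current: File | null,
  onPicked: (file: File | null) => void,
) {
  // Resetting value ensures re-selecting the same file still triggers `change`
  input.value = ''
  // `cancel` fires when the native dialog is closed without a selection
  input.oncancel = () => onPicked(current)
  input.onchange = () => onPicked(input.files?.[0] ?? null)
  input.click()
}
```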
+
@@ -108,17 +114,23 @@ const Uploader: FC = ({
)} {file && ( -
- -
- {file.name.replace(/(.yaml|.yml)$/, '')} - .yml +
+
+ +
+
+ {file.name} +
+ YAML + · + {formatFileSize(file.size)} +
- +
diff --git a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap index 7da09c4529..4ffcfa31e9 100644 --- a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap +++ b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap @@ -1804,8 +1804,85 @@ exports[`build chat item tree and get thread messages should get thread messages ] `; -exports[`build chat item tree and get thread messages should work with partial messages 1`] = ` +exports[`build chat item tree and get thread messages should work with partial messages 1 1`] = ` [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. 
Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "99", + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d", + }, { "children": [ { @@ -2078,6 +2155,178 @@ exports[`build chat item tree and get thread messages should work with partial m ] `; +exports[`build chat item tree and get thread messages should work with partial messages 2 1`] = ` +[ + { + "children": [ + { + "children": [], + "content": "237.", + "id": "ebb73fe2-15de-46dd-aab5-75416d8448eb", + "isAnswer": true, + "parentMessageId": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb", + "siblingIndex": 0, + }, + ], + "content": "123", + "id": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418", + }, + { + "children": [ + { + "children": [], + "content": "My number is 256.", + "id": "3553d508-3850-462e-8594-078539f940f9", + "isAnswer": true, + "parentMessageId": "question-3553d508-3850-462e-8594-078539f940f9", + "siblingIndex": 1, + }, + ], + "content": "123", + "id": "question-3553d508-3850-462e-8594-078539f940f9", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418", + }, + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [], + "content": "My number is 3e (approximately 8.15).", + "id": "9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "isAnswer": true, + "parentMessageId": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "siblingIndex": 0, + }, + ], + "content": "e", + "id": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "isAnswer": false, + "parentMessageId": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + }, + ], + "content": "My number is 2π (approximately 6.28).", + "id": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "isAnswer": true, + "parentMessageId": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "siblingIndex": 0, + }, + ], + "content": "π", + "id": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "isAnswer": false, + "parentMessageId": "46a49bb9-0881-459e-8c6a-24d20ae48d2f", + }, + ], + "content": "My number is 145.", + "id": "46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "isAnswer": true, + "parentMessageId": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "siblingIndex": 0, + }, + ], + "content": "78", + "id": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "isAnswer": false, + "parentMessageId": "3cded945-855a-4a24-aab7-43c7dd54664c", + }, + ], + "content": "My number is 7.89.", + "id": "3cded945-855a-4a24-aab7-43c7dd54664c", + "isAnswer": true, + "parentMessageId": "question-3cded945-855a-4a24-aab7-43c7dd54664c", + "siblingIndex": 0, + }, + ], + "content": "3.11", + "id": "question-3cded945-855a-4a24-aab7-43c7dd54664c", + "isAnswer": false, + "parentMessageId": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + }, + ], + "content": "My number is 22.", + "id": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "isAnswer": true, + "parentMessageId": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "siblingIndex": 0, + }, + ], + "content": "-5", + 
"id": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "isAnswer": false, + "parentMessageId": "93bac05d-1470-4ac9-b090-fe21cd7c3d55", + }, + ], + "content": "My number is 4782.", + "id": "93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "isAnswer": true, + "parentMessageId": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "siblingIndex": 0, + }, + ], + "content": "3306", + "id": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "isAnswer": false, + "parentMessageId": "9e51a13b-7780-4565-98dc-f2d8c3b1758f", + }, + ], + "content": "My number is 2048.", + "id": "9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "isAnswer": true, + "parentMessageId": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "siblingIndex": 0, + }, + ], + "content": "1024", + "id": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "isAnswer": false, + "parentMessageId": "507f9df9-1f06-4a57-bb38-f00228c42c22", + }, + ], + "content": "My number is 259.", + "id": "507f9df9-1f06-4a57-bb38-f00228c42c22", + "isAnswer": true, + "parentMessageId": "question-507f9df9-1f06-4a57-bb38-f00228c42c22", + "siblingIndex": 2, + }, + ], + "content": "123", + "id": "question-507f9df9-1f06-4a57-bb38-f00228c42c22", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418", + }, +] +`; + exports[`build chat item tree and get thread messages should work with real world messages 1`] = ` [ { diff --git a/web/app/components/base/chat/__tests__/partialMessages.json b/web/app/components/base/chat/__tests__/partialMessages.json new file mode 100644 index 0000000000..916c6ad254 --- /dev/null +++ b/web/app/components/base/chat/__tests__/partialMessages.json @@ -0,0 +1,122 @@ +[ + { + "id": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb", + "content": "123", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418" + }, + { + "id": "ebb73fe2-15de-46dd-aab5-75416d8448eb", + "content": "237.", + "isAnswer": true, + "parentMessageId": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb" + }, + { + "id": "question-3553d508-3850-462e-8594-078539f940f9", + "content": "123", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418" + }, + { + "id": "3553d508-3850-462e-8594-078539f940f9", + "content": "My number is 256.", + "isAnswer": true, + "parentMessageId": "question-3553d508-3850-462e-8594-078539f940f9" + }, + { + "id": "question-507f9df9-1f06-4a57-bb38-f00228c42c22", + "content": "123", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418" + }, + { + "id": "507f9df9-1f06-4a57-bb38-f00228c42c22", + "content": "My number is 259.", + "isAnswer": true, + "parentMessageId": "question-507f9df9-1f06-4a57-bb38-f00228c42c22" + }, + { + "id": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "content": "1024", + "isAnswer": false, + "parentMessageId": "507f9df9-1f06-4a57-bb38-f00228c42c22" + }, + { + "id": "9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "content": "My number is 2048.", + "isAnswer": true, + "parentMessageId": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f" + }, + { + "id": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "content": "3306", + "isAnswer": false, + "parentMessageId": "9e51a13b-7780-4565-98dc-f2d8c3b1758f" + }, + { + "id": "93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "content": "My number is 4782.", + "isAnswer": true, + "parentMessageId": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55" + }, + { + "id": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "content": "-5", + "isAnswer": false, + "parentMessageId": "93bac05d-1470-4ac9-b090-fe21cd7c3d55" + }, + 
{ + "id": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "content": "My number is 22.", + "isAnswer": true, + "parentMessageId": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7" + }, + { + "id": "question-3cded945-855a-4a24-aab7-43c7dd54664c", + "content": "3.11", + "isAnswer": false, + "parentMessageId": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7" + }, + { + "id": "3cded945-855a-4a24-aab7-43c7dd54664c", + "content": "My number is 7.89.", + "isAnswer": true, + "parentMessageId": "question-3cded945-855a-4a24-aab7-43c7dd54664c" + }, + { + "id": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "content": "78", + "isAnswer": false, + "parentMessageId": "3cded945-855a-4a24-aab7-43c7dd54664c" + }, + { + "id": "46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "content": "My number is 145.", + "isAnswer": true, + "parentMessageId": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f" + }, + { + "id": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "content": "π", + "isAnswer": false, + "parentMessageId": "46a49bb9-0881-459e-8c6a-24d20ae48d2f" + }, + { + "id": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "content": "My number is 2π (approximately 6.28).", + "isAnswer": true, + "parentMessageId": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c" + }, + { + "id": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "content": "e", + "isAnswer": false, + "parentMessageId": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c" + }, + { + "id": "9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "content": "My number is 3e (approximately 8.15).", + "isAnswer": true, + "parentMessageId": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719" + } +] diff --git a/web/app/components/base/chat/__tests__/utils.spec.ts b/web/app/components/base/chat/__tests__/utils.spec.ts index 1dead1c949..0bff8a77a1 100644 --- a/web/app/components/base/chat/__tests__/utils.spec.ts +++ b/web/app/components/base/chat/__tests__/utils.spec.ts @@ -7,6 +7,7 @@ import mixedTestMessages from './mixedTestMessages.json' import multiRootNodesMessages from './multiRootNodesMessages.json' import multiRootNodesWithLegacyTestMessages from './multiRootNodesWithLegacyTestMessages.json' import realWorldMessages from './realWorldMessages.json' +import partialMessages from './partialMessages.json' function visitNode(tree: ChatItemInTree | ChatItemInTree[], path: string): ChatItemInTree { return get(tree, path) @@ -256,9 +257,15 @@ describe('build chat item tree and get thread messages', () => { expect(threadMessages6_2).toMatchSnapshot() }) - const partialMessages = (realWorldMessages as ChatItemInTree[]).slice(-10) - const tree7 = buildChatItemTree(partialMessages) - it('should work with partial messages', () => { + const partialMessages1 = (realWorldMessages as ChatItemInTree[]).slice(-10) + const tree7 = buildChatItemTree(partialMessages1) + it('should work with partial messages 1', () => { expect(tree7).toMatchSnapshot() }) + + const partialMessages2 = (partialMessages as ChatItemInTree[]) + const tree8 = buildChatItemTree(partialMessages2) + it('should work with partial messages 2', () => { + expect(tree8).toMatchSnapshot() + }) }) diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 5169e65a59..eec636b478 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -102,21 +102,21 @@ const ChatInputArea = ({ setCurrentIndex(historyRef.current.length) handleSend() } - else if (e.key === 'ArrowUp' && !e.shiftKey && 
!e.nativeEvent.isComposing) { - // When the up key is pressed, output the previous element + else if (e.key === 'ArrowUp' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) { + // When the cmd + up key is pressed, output the previous element if (currentIndex > 0) { setCurrentIndex(currentIndex - 1) setQuery(historyRef.current[currentIndex - 1]) } } - else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing) { - // When the down key is pressed, output the next element + else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) { + // When the cmd + down key is pressed, output the next element if (currentIndex < historyRef.current.length - 1) { setCurrentIndex(currentIndex + 1) setQuery(historyRef.current[currentIndex + 1]) } else if (currentIndex === historyRef.current.length - 1) { - // If it is the last element, clear the input box + // If it is the last element, clear the input box setCurrentIndex(historyRef.current.length) setQuery('') } diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index 61dfaecffc..326805c930 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -127,19 +127,16 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] { lastAppendedLegacyAnswer = answerNode } else { - if (!parentMessageId) + if ( + !parentMessageId + || !allMessages.some(item => item.id === parentMessageId) // parent message might not be fetched yet, in this case we will append the question to the root nodes + ) rootNodes.push(questionNode) else map[parentMessageId]?.children!.push(questionNode) } } - // If no messages have parentMessageId=null (indicating a root node), - // then we likely have a partial chat history. In this case, - // use the first available message as the root node. - if (rootNodes.length === 0 && allMessages.length > 0) - rootNodes.push(map[allMessages[0]!.id]!) - return rootNodes } diff --git a/web/app/components/base/divider/index.tsx b/web/app/components/base/divider/index.tsx index 85ce886199..4b351dea99 100644 --- a/web/app/components/base/divider/index.tsx +++ b/web/app/components/base/divider/index.tsx @@ -1,17 +1,31 @@ import type { CSSProperties, FC } from 'react' import React from 'react' -import s from './style.module.css' +import { type VariantProps, cva } from 'class-variance-authority' +import classNames from '@/utils/classnames' -type Props = { - type?: 'horizontal' | 'vertical' - // orientation?: 'left' | 'right' | 'center' +const dividerVariants = cva( + 'bg-divider-regular', + { + variants: { + type: { + horizontal: 'w-full h-[0.5px] my-2', + vertical: 'w-[1px] h-full mx-2', + }, + }, + defaultVariants: { + type: 'horizontal', + }, + }, +) + +type DividerProps = { className?: string style?: CSSProperties -} +} & VariantProps -const Divider: FC = ({ type = 'horizontal', className = '', style }) => { +const Divider: FC = ({ type, className = '', style }) => { return ( -
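Note: the `buildChatItemTree` hunk above replaces the old "promote the first message" fallback with a per-node rule: a question whose parent was not fetched simply starts a new root thread. A simplified sketch of that rule (`ChatNode` stands in for `ChatItemInTree` and ignores the question/answer split):

```ts
type ChatNode = {
  id: string
  parentMessageId?: string | null
  children: ChatNode[]
}

function buildTree(all: ChatNode[]): ChatNode[] {
  const byId = new Map<string, ChatNode>()
  all.forEach(node => byId.set(node.id, node))

  const roots: ChatNode[] = []
  for (const node of all) {
    const parent = node.parentMessageId ? byId.get(node.parentMessageId) : undefined
    if (parent) {
      parent.children.push(node)
    }
    else {
      // The parent falls outside the fetched window (or doesn't exist),
      // so this node starts a new thread at the root level
      roots.push(node)
    }
  }
  return roots
}
```

This is the case the new `partialMessages.json` fixture exercises: several question nodes share a parent that was never fetched, and each one becomes its own root.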
+
) } diff --git a/web/app/components/base/divider/style.module.css b/web/app/components/base/divider/style.module.css deleted file mode 100644 index 9cb2601b1f..0000000000 --- a/web/app/components/base/divider/style.module.css +++ /dev/null @@ -1,9 +0,0 @@ -.divider { - @apply bg-gray-200; -} -.horizontal { - @apply w-full h-[0.5px] my-2; -} -.vertical { - @apply w-[1px] h-full mx-2; -} diff --git a/web/app/components/base/icons/assets/vender/features/document.svg b/web/app/components/base/icons/assets/vender/features/document.svg new file mode 100644 index 0000000000..dca0e91a52 --- /dev/null +++ b/web/app/components/base/icons/assets/vender/features/document.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/features/Document.json b/web/app/components/base/icons/src/vender/features/Document.json new file mode 100644 index 0000000000..fdd08d5254 --- /dev/null +++ b/web/app/components/base/icons/src/vender/features/Document.json @@ -0,0 +1,23 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "viewBox": "0 0 24 24", + "fill": "currentColor" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M20 22H4C3.44772 22 3 21.5523 3 21V3C3 2.44772 3.44772 2 4 2H20C20.5523 2 21 2.44772 21 3V21C21 21.5523 20.5523 22 20 22ZM7 6V10H11V6H7ZM7 12V14H17V12H7ZM7 16V18H17V16H7ZM13 7V9H17V7H13Z" + }, + "children": [] + } + ] + }, + "name": "Document" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/features/Document.tsx b/web/app/components/base/icons/src/vender/features/Document.tsx new file mode 100644 index 0000000000..84bf3a2f10 --- /dev/null +++ b/web/app/components/base/icons/src/vender/features/Document.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './Document.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'Document' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/features/index.ts b/web/app/components/base/icons/src/vender/features/index.ts index 2b8cb17c94..f246732226 100644 --- a/web/app/components/base/icons/src/vender/features/index.ts +++ b/web/app/components/base/icons/src/vender/features/index.ts @@ -7,3 +7,4 @@ export { default as Microphone01 } from './Microphone01' export { default as TextToAudio } from './TextToAudio' export { default as VirtualAssistant } from './VirtualAssistant' export { default as Vision } from './Vision' +export { default as Document } from './Document' diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx index 3e13db5d7f..b1b9ffe8c4 100644 --- a/web/app/components/base/toast/index.tsx +++ b/web/app/components/base/toast/index.tsx @@ -3,16 +3,19 @@ import type { ReactNode } from 'react' import React, { useEffect, useState } from 'react' import { createRoot } from 'react-dom/client' import { - CheckCircleIcon, - ExclamationTriangleIcon, - InformationCircleIcon, - XCircleIcon, -} from '@heroicons/react/20/solid' + RiAlertFill, + RiCheckboxCircleFill, + RiCloseLine, + RiErrorWarningFill, + RiInformation2Fill, +} from '@remixicon/react' import { createContext, useContext } from 'use-context-selector' +import ActionButton 
from '@/app/components/base/action-button' import classNames from '@/utils/classnames' export type IToastProps = { type?: 'success' | 'error' | 'warning' | 'info' + size?: 'md' | 'sm' duration?: number message: string children?: ReactNode @@ -21,60 +24,55 @@ export type IToastProps = { } type IToastContext = { notify: (props: IToastProps) => void + close: () => void } export const ToastContext = createContext({} as IToastContext) export const useToastContext = () => useContext(ToastContext) const Toast = ({ type = 'info', + size = 'md', message, children, className, }: IToastProps) => { + const { close } = useToastContext() // sometimes message is react node array. Not handle it. if (typeof message !== 'string') return null return
-
-
- {type === 'success' &&