Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly 2024-11-25 15:37:19 +08:00
commit 5ff9cee326
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
272 changed files with 7635 additions and 2740 deletions

View File

@ -1,5 +1,5 @@
FROM mcr.microsoft.com/devcontainers/python:3.10 FROM mcr.microsoft.com/devcontainers/python:3.12
# [Optional] Uncomment this section to install additional OS packages. # [Optional] Uncomment this section to install additional OS packages.
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
# && apt-get -y install --no-install-recommends <your-package-list-here> # && apt-get -y install --no-install-recommends <your-package-list-here>

View File

@ -1,7 +1,7 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the // For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
{ {
"name": "Python 3.10", "name": "Python 3.12",
"build": { "build": {
"context": "..", "context": "..",
"dockerfile": "Dockerfile" "dockerfile": "Dockerfile"

View File

@ -4,7 +4,7 @@ inputs:
python-version: python-version:
description: Python version to use and the Poetry installed with description: Python version to use and the Poetry installed with
required: true required: true
default: '3.10' default: '3.11'
poetry-version: poetry-version:
description: Poetry version to set up description: Poetry version to set up
required: true required: true

View File

@ -20,7 +20,6 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
- "3.10"
- "3.11" - "3.11"
- "3.12" - "3.12"

View File

@ -20,7 +20,6 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
- "3.10"
- "3.11" - "3.11"
- "3.12" - "3.12"

View File

@ -1,6 +1,8 @@
# CONTRIBUTING
So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly. So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.
We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part. We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve. This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
@ -10,14 +12,12 @@ In terms of licensing, please take a minute to read our short [License and Contr
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: [Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests: ### Feature requests
* If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try. * If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.
* If you want to pick one up from the existing issues, simply drop a comment below it saying so. * If you want to pick one up from the existing issues, simply drop a comment below it saying so.
A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes. A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.
Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's rundown of the areas each our team members are working on at the moment: Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's rundown of the areas each our team members are working on at the moment:
@ -40,7 +40,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Non-core features and minor enhancements | Low Priority | | Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature | | Valuable but not immediate | Future-Feature |
### Anything else (e.g. bug report, performance optimization, typo correction): ### Anything else (e.g. bug report, performance optimization, typo correction)
* Start coding right away. * Start coding right away.
@ -52,7 +52,6 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Non-critical bugs, performance boosts | Medium Priority | | Non-critical bugs, performance boosts | Medium Priority |
| Minor fixes (typos, confusing but working UI) | Low Priority | | Minor fixes (typos, confusing but working UI) | Low Priority |
## Installing ## Installing
Here are the steps to set up Dify for development: Here are the steps to set up Dify for development:
@ -63,7 +62,7 @@ Here are the steps to set up Dify for development:
Clone the forked repository from your terminal: Clone the forked repository from your terminal:
``` ```shell
git clone git@github.com:<github_username>/dify.git git clone git@github.com:<github_username>/dify.git
``` ```
@ -71,11 +70,11 @@ git clone git@github.com:<github_username>/dify.git
Dify requires the following dependencies to build, make sure they're installed on your system: Dify requires the following dependencies to build, make sure they're installed on your system:
- [Docker](https://www.docker.com/) * [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/) * [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) * [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) * [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x * [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. Installations ### 4. Installations
@ -85,7 +84,7 @@ Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) fo
### 5. Visit dify in your browser ### 5. Visit dify in your browser
To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running. To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
## Developing ## Developing
@ -97,9 +96,9 @@ To help you quickly navigate where your contribution fits, a brief, annotated ou
### Backend ### Backend
Difys backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login. Difys backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
``` ```text
[api/] [api/]
├── constants // Constant settings used throughout code base. ├── constants // Constant settings used throughout code base.
├── controllers // API route definitions and request handling logic. ├── controllers // API route definitions and request handling logic.
@ -121,7 +120,7 @@ Difys backend is written in Python using [Flask](https://flask.palletsproject
The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization. The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.
``` ```text
[web/] [web/]
├── app // layouts, pages, and components ├── app // layouts, pages, and components
│ ├── (commonLayout) // common layout used throughout the app │ ├── (commonLayout) // common layout used throughout the app
@ -149,10 +148,10 @@ The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typ
## Submitting your PR ## Submitting your PR
At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests). At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md). And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).
## Getting Help ## Getting Help
If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.

View File

@ -71,7 +71,7 @@ Dify 依赖以下工具和库:
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x - [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. 安装 ### 4. 安装

View File

@ -74,7 +74,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x - [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. インストール ### 4. インストール

View File

@ -73,7 +73,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) phiên bản 3.10.x - [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x
### 4. Cài đặt ### 4. Cài đặt
@ -153,4 +153,4 @@ Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giớ
## Nhận trợ giúp ## Nhận trợ giúp
Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng. Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.

View File

@ -1,5 +1,5 @@
# base image # base image
FROM python:3.10-slim-bookworm AS base FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api WORKDIR /app/api

View File

@ -42,7 +42,7 @@
5. Install dependencies 5. Install dependencies
```bash ```bash
poetry env use 3.10 poetry env use 3.12
poetry install poetry install
``` ```
@ -81,5 +81,3 @@
```bash ```bash
poetry run -C api bash dev/pytest/pytest_all_tests.sh poetry run -C api bash dev/pytest/pytest_all_tests.sh
``` ```

View File

@ -1,6 +1,11 @@
import os import os
import sys import sys
python_version = sys.version_info
if not ((3, 11) <= python_version < (3, 13)):
print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
raise SystemExit(1)
from configs import dify_config from configs import dify_config
if not dify_config.DEBUG: if not dify_config.DEBUG:
@ -30,9 +35,6 @@ from models import account, dataset, model, source, task, tool, tools, web # no
# DO NOT REMOVE ABOVE # DO NOT REMOVE ABOVE
if sys.version_info[:2] == (3, 10):
print("Warning: Python 3.10 will not be supported in the next version.")
warnings.simplefilter("ignore", ResourceWarning) warnings.simplefilter("ignore", ResourceWarning)

View File

@ -27,7 +27,6 @@ class DifyConfig(
# read from dotenv format config file # read from dotenv format config file
env_file=".env", env_file=".env",
env_file_encoding="utf-8", env_file_encoding="utf-8",
frozen=True,
# ignore extra attributes # ignore extra attributes
extra="ignore", extra="ignore",
) )

View File

@ -2,6 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
@ -17,6 +18,10 @@ api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload") api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
# Import other controllers # Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version from . import admin, apikey, extension, feature, ping, setup, version

View File

@ -1,7 +1,10 @@
import uuid import uuid
from typing import cast
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api from controllers.console import api
@ -12,15 +15,16 @@ from controllers.console.wraps import (
enterprise_license_required, enterprise_license_required,
setup_required, setup_required,
) )
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import ( from fields.app_fields import (
app_detail_fields, app_detail_fields,
app_detail_fields_with_site, app_detail_fields_with_site,
app_pagination_fields, app_pagination_fields,
) )
from libs.login import login_required from libs.login import login_required
from services.app_dsl_service import AppDslService from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@ -93,99 +97,6 @@ class AppListApi(Resource):
return app, 201 return app, 201
class AppImportDependenciesCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Check dependencies"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
leaked_dependencies = AppDslService.check_dependencies(
tenant_id=current_user.current_tenant_id, data=args["data"], account=current_user
)
return jsonable_encoder({"leaked": leaked_dependencies}), 200
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
)
return app, 201
class AppImportFromUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
)
return app, 201
class AppImportFromUrlDependenciesCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
leaked_dependencies = AppDslService.check_dependencies_from_url(
tenant_id=current_user.current_tenant_id, url=args["url"], account=current_user
)
return jsonable_encoder({"leaked": leaked_dependencies}), 200
class AppApi(Resource): class AppApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -263,10 +174,24 @@ class AppCopyApi(Resource):
parser.add_argument("icon_background", type=str, location="json") parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
data = AppDslService.export_dsl(app_model=app_model, include_secret=True) with Session(db.engine) as session:
app = AppDslService.import_and_create_new_app( import_service = AppDslService(session)
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
) account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
)
session.commit()
stmt = select(App).where(App.id == result.app.id)
app = session.scalar(stmt)
return app, 201 return app, 201
@ -407,10 +332,6 @@ class AppTraceApi(Resource):
api.add_resource(AppListApi, "/apps") api.add_resource(AppListApi, "/apps")
api.add_resource(AppImportDependenciesCheckApi, "/apps/import/dependencies/check")
api.add_resource(AppImportApi, "/apps/import")
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
api.add_resource(AppImportFromUrlDependenciesCheckApi, "/apps/import/url/dependencies/check")
api.add_resource(AppApi, "/apps/<uuid:app_id>") api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy") api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export") api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")

View File

@ -0,0 +1,90 @@
from typing import cast
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_fields
from libs.login import login_required
from models import Account
from services.app_dsl_service import AppDslService, ImportStatus
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("app_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class AppImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

View File

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
import pytz import pytz
from flask_login import current_user from flask_login import current_user
@ -314,7 +314,7 @@ def _get_conversation(app_model, conversation_id):
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
if not conversation.read_at: if not conversation.read_at:
conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None) conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
conversation.read_account_id = current_user.id conversation.read_account_id = current_user.id
db.session.commit() db.session.commit()

View File

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
@ -75,7 +75,7 @@ class AppSite(Resource):
setattr(site, attr_name, value) setattr(site, attr_name, value)
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site
@ -99,7 +99,7 @@ class AppSiteAccessTokenReset(Resource):
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site

View File

@ -21,7 +21,6 @@ from libs.login import current_user, login_required
from models import App from models import App
from models.account import Account from models.account import Account
from models.model import AppMode from models.model import AppMode
from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -130,34 +129,6 @@ class DraftWorkflowApi(Resource):
} }
class DraftWorkflowImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def post(self, app_model: App):
"""
Import draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
workflow = AppDslService.import_and_overwrite_workflow(
app_model=app_model, data=args["data"], account=current_user
)
return workflow
class AdvancedChatDraftWorkflowRunApi(Resource): class AdvancedChatDraftWorkflowRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -490,7 +461,6 @@ class ConvertToWorkflowApi(Resource):
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")

View File

@ -65,7 +65,7 @@ class ActivateApi(Resource):
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))

View File

@ -1,5 +1,5 @@
import logging import logging
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Optional from typing import Optional
import requests import requests
@ -108,7 +108,7 @@ class OAuthCallback(Resource):
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:

View File

@ -88,7 +88,7 @@ class DataSourceApi(Resource):
if action == "enable": if action == "enable":
if data_source_binding.disabled: if data_source_binding.disabled:
data_source_binding.disabled = False data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:
@ -97,7 +97,7 @@ class DataSourceApi(Resource):
if action == "disable": if action == "disable":
if not data_source_binding.disabled: if not data_source_binding.disabled:
data_source_binding.disabled = True data_source_binding.disabled = True
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:

View File

@ -1,6 +1,6 @@
import logging import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from datetime import datetime, timezone from datetime import UTC, datetime
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
@ -681,7 +681,7 @@ class DocumentProcessingApi(DocumentResource):
raise InvalidActionError("Document not in indexing state.") raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id document.paused_by = current_user.id
document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) document.paused_at = datetime.now(UTC).replace(tzinfo=None)
document.is_paused = True document.is_paused = True
db.session.commit() db.session.commit()
@ -761,7 +761,7 @@ class DocumentMetadataApi(DocumentResource):
document.doc_metadata[key] = value document.doc_metadata[key] = value
document.doc_type = doc_type document.doc_type = doc_type
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success", "message": "Document metadata updated."}, 200 return {"result": "success", "message": "Document metadata updated."}, 200
@ -803,7 +803,7 @@ class DocumentStatusApi(DocumentResource):
document.enabled = True document.enabled = True
document.disabled_at = None document.disabled_at = None
document.disabled_by = None document.disabled_by = None
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times
@ -820,9 +820,9 @@ class DocumentStatusApi(DocumentResource):
raise InvalidActionError("Document already disabled.") raise InvalidActionError("Document already disabled.")
document.enabled = False document.enabled = False
document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id document.disabled_by = current_user.id
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times
@ -837,9 +837,9 @@ class DocumentStatusApi(DocumentResource):
raise InvalidActionError("Document already archived.") raise InvalidActionError("Document already archived.")
document.archived = True document.archived = True
document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id document.archived_by = current_user.id
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
if document.enabled: if document.enabled:
@ -856,7 +856,7 @@ class DocumentStatusApi(DocumentResource):
document.archived = False document.archived = False
document.archived_at = None document.archived_at = None
document.archived_by = None document.archived_by = None
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times

View File

@ -1,5 +1,5 @@
import uuid import uuid
from datetime import datetime, timezone from datetime import UTC, datetime
import pandas as pd import pandas as pd
from flask import request from flask import request
@ -188,7 +188,7 @@ class DatasetDocumentSegmentApi(Resource):
raise InvalidActionError("Segment is already disabled.") raise InvalidActionError("Segment is already disabled.")
segment.enabled = False segment.enabled = False
segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id segment.disabled_by = current_user.id
db.session.commit() db.session.commit()

View File

@ -1,5 +1,5 @@
import logging import logging
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import reqparse from flask_restful import reqparse
@ -46,7 +46,7 @@ class CompletionApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:
@ -106,7 +106,7 @@ class ChatApi(InstalledAppResource):
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:

View File

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse from flask_restful import Resource, inputs, marshal_with, reqparse
@ -81,7 +81,7 @@ class InstalledAppsListApi(Resource):
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id, app_owner_tenant_id=app.tenant_id,
is_pinned=False, is_pinned=False,
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), last_used_at=datetime.now(UTC).replace(tzinfo=None),
) )
db.session.add(new_installed_app) db.session.add(new_installed_app)
db.session.commit() db.session.commit()

View File

@ -60,7 +60,7 @@ class AccountInitApi(Resource):
raise InvalidInvitationCodeError() raise InvalidInvitationCodeError()
invitation_code.status = "used" invitation_code.status = "used"
invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id invitation_code.used_by_account_id = account.id
@ -68,7 +68,7 @@ class AccountInitApi(Resource):
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = "active" account.status = "active"
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success"} return {"result": "success"}

View File

@ -1,5 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timezone from datetime import UTC, datetime
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
@ -198,7 +198,7 @@ def validate_and_get_api_token(scope=None):
if not api_token: if not api_token:
raise Unauthorized("Access token is invalid") raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return api_token return api_token

View File

@ -1,3 +1,4 @@
import uuid
from typing import Optional from typing import Optional
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity

View File

@ -12,7 +12,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter: class ModelConfigConverter:
@classmethod @classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
""" """
Convert app model config dict to entity. Convert app model config dict to entity.
:param app_config: app config :param app_config: app config
@ -39,27 +39,23 @@ class ModelConfigConverter:
) )
if model_credentials is None: if model_credentials is None:
if not skip_check: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
else:
model_credentials = {}
if not skip_check: # check model
# check model provider_model = provider_model_bundle.configuration.get_provider_model(
provider_model = provider_model_bundle.configuration.get_provider_model( model=model_config.model, model_type=ModelType.LLM
model=model_config.model, model_type=ModelType.LLM )
)
if provider_model is None: if provider_model is None:
model_name = model_config.model model_name = model_config.model
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE: if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION: elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config # model config
completion_params = model_config.parameters completion_params = model_config.parameters
@ -77,7 +73,7 @@ class ModelConfigConverter:
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
if not skip_check and not model_schema: if not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")
return ModelConfigWithCredentialsEntity( return ModelConfigWithCredentialsEntity(

View File

@ -1,4 +1,5 @@
from core.app.app_config.entities import ( from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity, AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity,
PromptTemplateEntity, PromptTemplateEntity,
@ -25,7 +26,9 @@ class PromptTemplateConfigManager:
chat_prompt_messages = [] chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []): for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append( chat_prompt_messages.append(
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} AdvancedChatMessageEntity(
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
) )
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -88,7 +88,7 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum): class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input" TEXT_INPUT = "text-input"
SELECT = "select" SELECT = "select"
PARAGRAPH = "paragraph" PARAGRAPH = "paragraph"

View File

@ -138,7 +138,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -139,7 +139,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -1,5 +1,5 @@
import json import json
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import VariableEntityType from core.app.app_config.entities import VariableEntityType
@ -7,7 +7,7 @@ from core.file import File, FileUploadConfig
from factories import file_factory from factories import file_factory
if TYPE_CHECKING: if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity from core.app.app_config.entities import VariableEntity
class BaseAppGenerator: class BaseAppGenerator:
@ -15,23 +15,23 @@ class BaseAppGenerator:
self, self,
*, *,
user_inputs: Optional[Mapping[str, Any]], user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig", variables: Sequence["VariableEntity"],
tenant_id: str,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
user_inputs = { user_inputs = {
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var) var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
for var in variables for var in variables
} }
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File # Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables} entity_dictionary = {item.variable: item for item in variables}
# Convert single file to File # Convert single file to File
files_inputs = { files_inputs = {
k: file_factory.build_from_mapping( k: file_factory.build_from_mapping(
mapping=v, mapping=v,
tenant_id=app_config.tenant_id, tenant_id=tenant_id,
config=FileUploadConfig( config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
@ -45,7 +45,7 @@ class BaseAppGenerator:
file_list_inputs = { file_list_inputs = {
k: file_factory.build_from_mappings( k: file_factory.build_from_mappings(
mappings=v, mappings=v,
tenant_id=app_config.tenant_id, tenant_id=tenant_id,
config=FileUploadConfig( config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,

View File

@ -142,7 +142,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -123,7 +123,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id
),
query=query, query=query,
files=file_objs, files=file_objs,
user_id=user.id, user_id=user.id,

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy import and_ from sqlalchemy import and_
@ -200,7 +200,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.commit() db.session.commit()
db.session.refresh(conversation) db.session.refresh(conversation)
else: else:
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
message = Message( message = Message(

View File

@ -108,7 +108,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
files=system_files, files=system_files,
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,

View File

@ -43,7 +43,6 @@ from core.workflow.graph_engine.entities.event import (
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.iteration import IterationNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
@ -160,8 +159,6 @@ class WorkflowBasedAppRunner(AppRunner):
user_inputs=user_inputs, user_inputs=user_inputs,
variable_pool=variable_pool, variable_pool=variable_pool,
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get("data", {})),
) )
return graph, variable_pool return graph, variable_pool

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
@ -11,7 +11,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base import BaseNodeData
class QueueEvent(str, Enum): class QueueEvent(StrEnum):
""" """
QueueEvent enum QueueEvent enum
""" """

View File

@ -1,7 +1,7 @@
import json import json
import time import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -144,7 +144,7 @@ class WorkflowCycleManage:
workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
db.session.refresh(workflow_run) db.session.refresh(workflow_run)
@ -191,7 +191,7 @@ class WorkflowCycleManage:
workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
@ -211,15 +211,18 @@ class WorkflowCycleManage:
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.elapsed_time = ( workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds() ).total_seconds()
db.session.commit() db.session.commit()
db.session.refresh(workflow_run)
db.session.close() db.session.close()
with Session(db.engine, expire_on_commit=False) as session:
session.add(workflow_run)
session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
TraceTask( TraceTask(
@ -259,7 +262,7 @@ class WorkflowCycleManage:
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
} }
) )
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) session.add(workflow_node_execution)
session.commit() session.commit()
@ -282,7 +285,7 @@ class WorkflowCycleManage:
execution_metadata = ( execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
) )
finished_at = datetime.now(timezone.utc).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
@ -326,7 +329,7 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = ( execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@ -654,7 +657,7 @@ class WorkflowCycleManage:
if event.error is None if event.error is None
else WorkflowNodeExecutionStatus.FAILED, else WorkflowNodeExecutionStatus.FAILED,
error=None, error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata, execution_metadata=event.metadata,
finished_at=int(time.time()), finished_at=int(time.time()),

View File

@ -246,7 +246,7 @@ class ProviderConfiguration(BaseModel):
if provider_record: if provider_record:
provider_record.encrypted_config = json.dumps(credentials) provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True provider_record.is_valid = True
provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
provider_record = Provider() provider_record = Provider()
@ -401,7 +401,7 @@ class ProviderConfiguration(BaseModel):
if provider_model_record: if provider_model_record:
provider_model_record.encrypted_config = json.dumps(credentials) provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True provider_model_record.is_valid = True
provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
provider_model_record = ProviderModel() provider_model_record = ProviderModel()
@ -474,7 +474,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.enabled = True model_setting.enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting() model_setting = ProviderModelSetting()
@ -508,7 +508,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.enabled = False model_setting.enabled = False
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting() model_setting = ProviderModelSetting()
@ -574,7 +574,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.load_balancing_enabled = True model_setting.load_balancing_enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting() model_setting = ProviderModelSetting()
@ -608,7 +608,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.load_balancing_enabled = False model_setting.load_balancing_enabled = False
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting() model_setting = ProviderModelSetting()

View File

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class FileType(str, Enum): class FileType(StrEnum):
IMAGE = "image" IMAGE = "image"
DOCUMENT = "document" DOCUMENT = "document"
AUDIO = "audio" AUDIO = "audio"
@ -16,7 +16,7 @@ class FileType(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(str, Enum): class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url" REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file" LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file" TOOL_FILE = "tool_file"
@ -29,7 +29,7 @@ class FileTransferMethod(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(str, Enum): class FileBelongsTo(StrEnum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
@ -41,7 +41,7 @@ class FileBelongsTo(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileAttribute(str, Enum): class FileAttribute(StrEnum):
TYPE = "type" TYPE = "type"
SIZE = "size" SIZE = "size"
NAME = "name" NAME = "name"
@ -51,5 +51,5 @@ class FileAttribute(str, Enum):
EXTENSION = "extension" EXTENSION = "extension"
class ArrayFileAttribute(str, Enum): class ArrayFileAttribute(StrEnum):
LENGTH = "length" LENGTH = "length"

View File

@ -3,7 +3,12 @@ import base64
from configs import dify_config from configs import dify_config
from core.file import file_repository from core.file import file_repository
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
return file.remote_url return file.remote_url
case FileAttribute.EXTENSION: case FileAttribute.EXTENSION:
return file.extension return file.extension
case _:
raise ValueError(f"Invalid file attribute: {attr}")
def to_prompt_message_content( def to_prompt_message_content(
f: File, f: File,
/, /,
*, *,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
): ):
"""
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
This function takes a File object and converts it to an appropriate PromptMessageContent
object, which can be used as a prompt for image or audio-based AI models.
Args:
f (File): The File object to convert.
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
Returns:
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
Raises:
ValueError: If the file type is not supported or if required data is missing.
"""
match f.type: match f.type:
case FileType.IMAGE: case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f) data = _to_url(f)
else: else:
@ -65,7 +52,7 @@ def to_prompt_message_content(
return ImagePromptMessageContent(data=data, detail=image_detail_config) return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO: case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f) encoded_string = _get_encoded_string(f)
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@ -74,9 +61,20 @@ def to_prompt_message_content(
data = _to_url(f) data = _to_url(f)
else: else:
data = _to_base64_data_string(f) data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _: case _:
raise ValueError("file type f.type is not supported") raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /): def download(f: File, /):
@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /):
case FileTransferMethod.REMOTE_URL: case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status() response.raise_for_status()
content = response.content data = response.content
encoded_string = base64.b64encode(content).decode("utf-8")
return encoded_string
case FileTransferMethod.LOCAL_FILE: case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f) upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key) data = _download_file_content(upload_file.key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case FileTransferMethod.TOOL_FILE: case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f) tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key) data = _download_file_content(tool_file.file_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string encoded_string = base64.b64encode(data).decode("utf-8")
case _: return encoded_string
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def _to_base64_data_string(f: File, /): def _to_base64_data_string(f: File, /):
@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /):
return f"data:{f.mime_type};base64,{encoded_string}" return f"data:{f.mime_type};base64,{encoded_string}"
def _file_to_encoded_string(f: File, /):
match f.type:
case FileType.IMAGE:
return _to_base64_data_string(f)
case FileType.VIDEO:
return _to_base64_data_string(f)
case FileType.AUDIO:
return _get_encoded_string(f)
case _:
raise ValueError(f"file type {f.type} is not supported")
def _to_url(f: File, /): def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None: if f.remote_url is None:

View File

@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import StrEnum
from threading import Lock from threading import Lock
from typing import Any, Optional from typing import Any, Optional
@ -31,7 +31,7 @@ class CodeExecutionResponse(BaseModel):
data: Data data: Data
class CodeLanguage(str, Enum): class CodeLanguage(StrEnum):
PYTHON3 = "python3" PYTHON3 = "python3"
JINJA2 = "jinja2" JINJA2 = "jinja2"
JAVASCRIPT = "javascript" JAVASCRIPT = "javascript"

View File

@ -86,7 +86,7 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except ObjectDeletedError: except ObjectDeletedError:
logging.warning("Document deleted, document id: {}".format(dataset_document.id)) logging.warning("Document deleted, document id: {}".format(dataset_document.id))
@ -94,7 +94,7 @@ class IndexingRunner:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument): def run_in_splitting_status(self, dataset_document: DatasetDocument):
@ -142,13 +142,13 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def run_in_indexing_status(self, dataset_document: DatasetDocument): def run_in_indexing_status(self, dataset_document: DatasetDocument):
@ -200,13 +200,13 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def indexing_estimate( def indexing_estimate(
@ -372,7 +372,7 @@ class IndexingRunner:
after_indexing_status="splitting", after_indexing_status="splitting",
extra_update_params={ extra_update_params={
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
@ -464,7 +464,7 @@ class IndexingRunner:
doc_store.add_documents(documents) doc_store.add_documents(documents)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status( self._update_document_index_status(
document_id=dataset_document.id, document_id=dataset_document.id,
after_indexing_status="indexing", after_indexing_status="indexing",
@ -479,7 +479,7 @@ class IndexingRunner:
dataset_document_id=dataset_document.id, dataset_document_id=dataset_document.id,
update_params={ update_params={
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
@ -680,7 +680,7 @@ class IndexingRunner:
after_indexing_status="completed", after_indexing_status="completed",
extra_update_params={ extra_update_params={
DatasetDocument.tokens: tokens, DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.error: None, DatasetDocument.error: None,
}, },
@ -705,7 +705,7 @@ class IndexingRunner:
{ {
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
DocumentSegment.enabled: True, DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
) )
@ -738,7 +738,7 @@ class IndexingRunner:
{ {
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
DocumentSegment.enabled: True, DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
) )
@ -849,7 +849,7 @@ class IndexingRunner:
doc_store.add_documents(documents) doc_store.add_documents(documents)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status( self._update_document_index_status(
document_id=dataset_document.id, document_id=dataset_document.id,
after_indexing_status="indexing", after_indexing_status="indexing",
@ -864,7 +864,7 @@ class IndexingRunner:
dataset_document_id=dataset_document.id, dataset_document_id=dataset_document.id,
update_params={ update_params={
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
pass pass

View File

@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager from core.file import file_manager
from core.file.models import FileType
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
@ -27,7 +27,7 @@ class TokenBufferMemory:
def get_history_prompt_messages( def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> list[PromptMessage]: ) -> Sequence[PromptMessage]:
""" """
Get history prompt messages. Get history prompt messages.
:param max_token_limit: max token limit :param max_token_limit: max token limit
@ -102,12 +102,11 @@ class TokenBufferMemory:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query)) prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs: for file in file_objs:
if file.type in {FileType.IMAGE, FileType.AUDIO}: prompt_message = file_manager.to_prompt_message_content(
prompt_message = file_manager.to_prompt_message_content( file,
file, image_detail_config=detail,
image_detail_config=detail, )
) prompt_message_contents.append(prompt_message)
prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))

View File

@ -136,10 +136,10 @@ class ModelInstance:
def invoke_llm( def invoke_llm(
self, self,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None, tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -31,7 +32,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
@ -60,7 +61,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
): ):
@ -90,7 +91,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
@ -120,7 +121,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:

View File

@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa
from .message_entities import ( from .message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
AudioPromptMessageContent, AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContent, PromptMessageContent,
@ -37,4 +38,5 @@ __all__ = [
"LLMResultChunk", "LLMResultChunk",
"LLMResultChunkDelta", "LLMResultChunkDelta",
"AudioPromptMessageContent", "AudioPromptMessageContent",
"DocumentPromptMessageContent",
] ]

View File

@ -1,6 +1,7 @@
from abc import ABC from abc import ABC
from enum import Enum from collections.abc import Sequence
from typing import Optional from enum import Enum, StrEnum
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -48,7 +49,7 @@ class PromptMessageFunction(BaseModel):
function: PromptMessageTool function: PromptMessageTool
class PromptMessageContentType(Enum): class PromptMessageContentType(StrEnum):
""" """
Enum class for prompt message content type. Enum class for prompt message content type.
""" """
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
IMAGE = "image" IMAGE = "image"
AUDIO = "audio" AUDIO = "audio"
VIDEO = "video" VIDEO = "video"
DOCUMENT = "document"
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
@ -93,7 +95,7 @@ class ImagePromptMessageContent(PromptMessageContent):
Model class for image prompt message content. Model class for image prompt message content.
""" """
class DETAIL(str, Enum): class DETAIL(StrEnum):
LOW = "low" LOW = "low"
HIGH = "high" HIGH = "high"
@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
class PromptMessage(ABC, BaseModel): class PromptMessage(ABC, BaseModel):
""" """
Model class for prompt message. Model class for prompt message.
""" """
role: PromptMessageRole role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None content: Optional[str | Sequence[PromptMessageContent]] = None
name: Optional[str] = None name: Optional[str] = None
def is_empty(self) -> bool: def is_empty(self) -> bool:

View File

@ -1,5 +1,5 @@
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -82,9 +82,12 @@ class ModelFeature(Enum):
AGENT_THOUGHT = "agent-thought" AGENT_THOUGHT = "agent-thought"
VISION = "vision" VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call" STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
class DefaultParameterName(str, Enum): class DefaultParameterName(StrEnum):
""" """
Enum class for parameter template variable. Enum class for parameter template variable.
""" """

View File

@ -1,6 +1,6 @@
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator, Sequence
from typing import Optional, Union from typing import Optional, Union
from pydantic import ConfigDict from pydantic import ConfigDict
@ -41,7 +41,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -96,7 +96,7 @@ class LargeLanguageModel(AIModel):
model_parameters=model_parameters, model_parameters=model_parameters,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=tools, tools=tools,
stop=stop, stop=list(stop) if stop else None,
stream=stream, stream=stream,
) )
@ -176,7 +176,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -318,7 +318,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -364,7 +364,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -411,7 +411,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -459,7 +459,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

Binary file not shown.

Before

Width:  |  Height:  |  Size: 277 KiB

View File

@ -1,15 +0,0 @@
<svg width="68" height="24" viewBox="0 0 68 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Gemini">
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M50.6875 4.37014C48.3498 4.59292 46.5349 6.41319 46.3337 8.72764C46.1446 6.44662 44.2677 4.56074 41.9805 4.3737C44.2762 4.1997 46.152 2.28299 46.3373 0C46.4882 2.28911 48.405 4.20047 50.6875 4.37014ZM15.4567 9.41141L13.9579 10.9076C9.92941 6.64892 2.69298 9.97287 3.17317 15.8112C3.22394 23.108 14.5012 24.4317 15.3628 16.8809H9.52096L9.50061 14.9149H17.3595C18.8163 23.1364 8.44367 27.0292 3.19453 21.238C0.847044 18.7556 0.363651 14.7682 1.83717 11.7212C4.1129 6.62089 11.6505 5.29845 15.4567 9.41141ZM45.5915 23.5989H47.6945C47.6944 22.9155 47.6945 22.2307 47.6946 21.5452V21.5325C47.6948 19.8907 47.695 18.2453 47.6924 16.6072C47.6914 15.9407 47.6161 15.2823 47.4024 14.647C46.4188 11.2828 41.4255 11.4067 39.8332 14.214C38.5637 11.4171 34.4009 11.5236 32.8538 14.0084L32.8082 13.9976V12.4806L32.4233 12.4804H32.4224C31.8687 12.4801 31.3324 12.4798 30.7949 12.4811V23.5848L32.8977 23.5672C32.8981 22.9411 32.8979 22.3122 32.8977 21.6822V21.6812V21.6802V21.6791V21.6781V21.6771V21.676V21.676V21.6759V21.6758V21.6757V21.6757V21.6756C32.8973 20.204 32.8969 18.7261 32.904 17.2614C32.8889 15.3646 34.5674 13.5687 36.5358 14.124C37.7794 14.3298 38.1851 15.6148 38.1761 16.7257C38.1821 17.7019 38.18 18.6824 38.178 19.6633V19.6638C38.1752 20.9756 38.1724 22.2881 38.1891 23.5919L40.2846 23.5731C40.2929 22.7511 40.2881 21.9245 40.2832 21.0966C40.2753 19.7402 40.2674 18.3805 40.317 17.0328C40.4418 15.2122 42.0141 13.6186 43.9064 14.1168C45.2685 14.3231 45.6136 15.7748 45.5882 16.9545C45.5938 18.4959 45.5929 20.0492 45.5921 21.5968V21.5991V21.6014V21.6037V21.606V21.6083V21.6106C45.5917 22.2749 45.5913 22.9382 45.5915 23.5989ZM20.6167 18.4408C20.5625 21.9486 25.2121 23.6996 27.2993 20.0558L29.1566 20.9592C27.8157 23.7067 24.2337 24.7424 21.5381 23.4213C18.0052 21.7253 17.41 16.5007 20.0334 13.7517C21.4609 12.1752 23.7291 11.7901 25.7206 12.3653C28.3408 13.1257 29.4974 15.8937 29.326 18.4399C27.5547 18.4415 25.7971 18.4412 24.0364 18.4409C22.8993 18.4407 21.7609 18.4405 20.6167 18.4408ZM27.1041 16.6957C26.7048 13.1033 21.2867 13.2256 20.7494 16.6957H27.1041ZM53.543 23.5999H55.6206L55.6206 22.4361C55.6205 20.7877 55.6205 19.1443 55.6207 17.4939C55.6208 16.8853 55.7234 16.297 56.0063 15.7531C56.6115 14.3862 58.1745 13.7002 59.5927 14.1774C60.7512 14.4455 61.2852 15.6069 61.2762 16.7154C61.2774 18.3497 61.2771 19.9826 61.2769 21.6162V21.6166V21.617V21.6174V21.6179L61.2766 23.6007H63.3698C63.3913 22.0924 63.3869 20.584 63.3826 19.0755V19.0754V19.0753V19.0753V19.0752C63.3799 18.1682 63.3773 17.2612 63.3803 16.3541C63.3796 15.8622 63.3103 15.3765 63.1698 14.9052C62.3248 11.5142 57.3558 11.2385 55.5828 14.0038L55.5336 13.9905V12.4917H53.539C53.4898 12.7313 53.4934 23.4113 53.543 23.5999ZM49.6211 12.4944H51.7065V23.5994H49.6211V12.4944ZM65.1035 23.5991H67.1831C67.2367 23.2198 67.2133 12.6566 67.1634 12.4983H65.1035V23.5991ZM52.1504 8.67829C52.1709 10.4847 49.2418 10.7058 49.1816 8.65714C49.2189 6.5948 52.2437 6.81331 52.1504 8.67829ZM66.1387 10.1324C64.2712 10.1609 64.1316 7.19881 66.1559 7.17114C68.1709 7.19817 68.0215 10.2087 66.1387 10.1324Z" fill="url(#paint0_linear_14286_118464)"/>
</g>
<defs>
<linearGradient id="paint0_linear_14286_118464" x1="-2" y1="0.999998" x2="67.9999" y2="27.5002" gradientUnits="userSpaceOnUse">
<stop stop-color="#7798E0"/>
<stop offset="0.210002" stop-color="#086FFF"/>
<stop offset="0.345945" stop-color="#086FFF"/>
<stop offset="0.591777" stop-color="#479AFF"/>
<stop offset="0.895892" stop-color="#B7C4FA"/>
<stop offset="1" stop-color="#B5C5F9"/>
</linearGradient>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 57 KiB

View File

@ -1,11 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="6" fill="url(#paint0_linear_7301_16076)"/>
<path d="M20 12.0116C15.7043 12.42 12.3692 15.757 11.9995 20C11.652 15.8183 8.20301 12.361 4 12.0181C8.21855 11.6991 11.6656 8.1853 12.006 4C12.2833 8.19653 15.8057 11.7005 20 12.0116Z" fill="white" fill-opacity="0.88"/>
<defs>
<linearGradient id="paint0_linear_7301_16076" x1="-9" y1="29.5" x2="19.4387" y2="1.43791" gradientUnits="userSpaceOnUse">
<stop offset="0.192878" stop-color="#1C7DFF"/>
<stop offset="0.520213" stop-color="#1C69FF"/>
<stop offset="1" stop-color="#F0DCD6"/>
</linearGradient>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 689 B

View File

@ -1,10 +0,0 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class GPUStackProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -1,120 +0,0 @@
provider: gpustack
label:
en_US: GPUStack
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: endpoint_url
label:
zh_Hans: 服务器地址
en_US: Server URL
type: text-input
required: true
placeholder:
zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 输入您的 API Key
en_US: Enter your API Key
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择补全类型
en_US: Select completion type
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: "8192"
placeholder:
zh_Hans: 输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens_to_sample
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
show_on:
- variable: __model_type
value: llm
default: "8192"
type: text-input
- variable: function_calling_type
show_on:
- variable: __model_type
value: llm
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: function_call
label:
en_US: Function Call
zh_Hans: Function Call
- value: tool_call
label:
en_US: Tool Call
zh_Hans: Tool Call
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- variable: vision_support
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: Vision 支持
en_US: Vision Support
type: select
required: false
default: no_support
options:
- value: support
label:
en_US: Support
zh_Hans: 支持
- value: no_support
label:
en_US: Not Support
zh_Hans: 不支持

View File

@ -1,45 +0,0 @@
from collections.abc import Generator
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
OAIAPICompatLargeLanguageModel,
)
class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return super()._invoke(
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
credentials["mode"] = "chat"

View File

@ -1,146 +0,0 @@
from json import dumps
from typing import Optional
import httpx
from requests import post
from yarl import URL
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class GPUStackRerankModel(RerankModel):
"""
Model class for GPUStack rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
endpoint_url = credentials["endpoint_url"]
headers = {
"Authorization": f"Bearer {credentials.get('api_key')}",
"Content-Type": "application/json",
}
data = {"model": model, "query": query, "documents": docs, "top_n": top_n}
try:
response = post(
str(URL(endpoint_url) / "v1" / "rerank"),
headers=headers,
data=dumps(data),
timeout=10,
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results["results"]:
index = result["index"]
if "document" in result:
text = result["document"]["text"]
else:
text = docs[index]
rerank_document = RerankDocument(
index=index,
text=text,
score=result["relevance_score"],
)
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity

View File

@ -1,35 +0,0 @@
from typing import Optional
from yarl import URL
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
)
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
OAICompatEmbeddingModel,
)
class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
"""
Model class for GPUStack text embedding model.
"""
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
return super()._invoke(model, credentials, texts, user, input_type)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")

View File

@ -1,55 +0,0 @@
model: claude-3-5-sonnet-v2@20241022
label:
en_US: Claude 3.5 Sonnet v2
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 8192
min: 1
max: 8192
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

View File

@ -1,3 +0,0 @@
<svg width="1200" height="925" viewBox="0 0 1200 925" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M780.152 250.999L907.882 462.174C907.882 462.174 880.925 510.854 867.43 535.21C834.845 594.039 764.171 612.49 710.442 508.333L420.376 0H0L459.926 803.307C552.303 964.663 787.366 964.663 879.743 803.307C989.874 610.952 1089.87 441.97 1200 249.646L1052.28 0H639.519L780.152 250.999Z" fill="#3366FF"/>
</svg>

Before

Width:  |  Height:  |  Size: 417 B

View File

@ -1,83 +0,0 @@
from decimal import Decimal
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
PriceConfig,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
features = []
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features,
model_properties={
ModelPropertyKey.MODE: credentials.get("mode"),
},
parameter_rules=[
ParameterRule(
name=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
default=float(credentials.get("temperature", 0.7)),
min=0,
max=2,
precision=2,
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P"),
type=ParameterType.FLOAT,
default=float(credentials.get("top_p", 1)),
min=0,
max=1,
precision=2,
),
ParameterRule(
name=DefaultParameterName.TOP_K.value,
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
default=int(credentials.get("top_k", 50)),
min=-2147483647,
max=2147483647,
precision=0,
),
ParameterRule(
name=DefaultParameterName.MAX_TOKENS.value,
label=I18nObject(en_US="Max Tokens"),
type=ParameterType.INT,
default=512,
min=1,
max=int(credentials.get("max_tokens_to_sample", 4096)),
),
],
pricing=PriceConfig(
input=Decimal(credentials.get("input_price", 0)),
output=Decimal(credentials.get("output_price", 0)),
unit=Decimal(credentials.get("unit", 0)),
currency=credentials.get("currency", "USD"),
),
)
if credentials["mode"] == "chat":
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials["mode"] == "completion":
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
return entity

View File

@ -1,10 +0,0 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class VesslAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -1,56 +0,0 @@
provider: vessl_ai
label:
en_US: VESSL AI
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.png
background: "#F1EFED"
help:
title:
en_US: How to deploy VESSL AI LLM Model Endpoint
url:
en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment
supported_model_types:
- llm
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
placeholder:
en_US: Enter model name
credential_form_schemas:
- variable: endpoint_url
label:
en_US: Endpoint Url
type: text-input
required: true
placeholder:
en_US: Enter VESSL AI service endpoint url
- variable: api_key
required: true
label:
en_US: API Key
type: secret-input
placeholder:
en_US: Enter VESSL AI secret key
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion Mode
type: select
required: false
default: chat
placeholder:
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
- value: chat
label:
en_US: Chat

View File

@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>

Before

Width:  |  Height:  |  Size: 356 B

View File

@ -1,63 +0,0 @@
model: grok-beta
label:
en_US: Grok beta
model_type: llm
features:
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -1,37 +0,0 @@
from collections.abc import Generator
from typing import Optional, Union
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
credentials["mode"] = LLMMode.CHAT.value
credentials["function_calling_type"] = "tool_call"

View File

@ -1,25 +0,0 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class XAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -1,38 +0,0 @@
provider: x
label:
en_US: xAI
description:
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
icon_small:
en_US: x-ai-logo.svg
icon_large:
en_US: x-ai-logo.svg
help:
title:
en_US: Get your token from xAI
zh_Hans: 从 xAI 获取 token
url:
en_US: https://x.ai/api
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
default: https://api.x.ai/v1
placeholder:
zh_Hans: 在此输入您的 API Base
en_US: Enter your API Base

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
@ -122,7 +122,7 @@ trace_info_info_map = {
} }
class TraceTaskName(str, Enum): class TraceTaskName(StrEnum):
CONVERSATION_TRACE = "conversation" CONVERSATION_TRACE = "conversation"
WORKFLOW_TRACE = "workflow" WORKFLOW_TRACE = "workflow"
MESSAGE_TRACE = "message" MESSAGE_TRACE = "message"

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -39,7 +39,7 @@ def validate_input_output(v, field_name):
return v return v
class LevelEnum(str, Enum): class LevelEnum(StrEnum):
DEBUG = "DEBUG" DEBUG = "DEBUG"
WARNING = "WARNING" WARNING = "WARNING"
ERROR = "ERROR" ERROR = "ERROR"
@ -178,7 +178,7 @@ class LangfuseSpan(BaseModel):
return validate_input_output(v, field_name) return validate_input_output(v, field_name)
class UnitEnum(str, Enum): class UnitEnum(StrEnum):
CHARACTERS = "CHARACTERS" CHARACTERS = "CHARACTERS"
TOKENS = "TOKENS" TOKENS = "TOKENS"
SECONDS = "SECONDS" SECONDS = "SECONDS"

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -8,7 +8,7 @@ from pydantic_core.core_schema import ValidationInfo
from core.ops.utils import replace_text_with_content from core.ops.utils import replace_text_with_content
class LangSmithRunType(str, Enum): class LangSmithRunType(StrEnum):
tool = "tool" tool = "tool"
chain = "chain" chain = "chain"
llm = "llm" llm = "llm"

View File

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
class ModelMode(str, enum.Enum): class ModelMode(enum.StrEnum):
COMPLETION = "completion" COMPLETION = "completion"
CHAT = "chat" CHAT = "chat"

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import cast from typing import cast
from core.model_runtime.entities import ( from core.model_runtime.entities import (
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
class PromptMessageUtil: class PromptMessageUtil:
@staticmethod @staticmethod
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
""" """
Prompt messages to prompt for saving. Prompt messages to prompt for saving.
:param model_mode: model mode :param model_mode: model mode

View File

@ -12,7 +12,7 @@ class CleanProcessor:
# Unicode U+FFFE # Unicode U+FFFE
text = re.sub("\ufffe", "", text) text = re.sub("\ufffe", "", text)
rules = process_rule["rules"] if process_rule else None rules = process_rule["rules"] if process_rule else {}
if "pre_processing_rules" in rules: if "pre_processing_rules" in rules:
pre_processing_rules = rules["pre_processing_rules"] pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules: for pre_processing_rule in pre_processing_rules:

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import StrEnum
class KeyWordType(str, Enum): class KeyWordType(StrEnum):
JIEBA = "jieba" JIEBA = "jieba"

View File

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class VectorType(str, Enum): class VectorType(StrEnum):
ANALYTICDB = "analyticdb" ANALYTICDB = "analyticdb"
CHROMA = "chroma" CHROMA = "chroma"
MILVUS = "milvus" MILVUS = "milvus"

View File

@ -114,10 +114,10 @@ class WordExtractor(BaseExtractor):
mime_type=mime_type or "", mime_type=mime_type or "",
created_by=self.user_id, created_by=self.user_id,
created_by_role=CreatedByRole.ACCOUNT, created_by_role=CreatedByRole.ACCOUNT,
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
used=True, used=True,
used_by=self.user_id, used_by=self.user_id,
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
) )
db.session.add(upload_file) db.session.add(upload_file)

View File

@ -27,11 +27,11 @@ class RerankModelRunner(BaseRerankRunner):
:return: :return:
""" """
docs = [] docs = []
doc_id = set() doc_ids = set()
unique_documents = [] unique_documents = []
for document in documents: for document in documents:
if document.provider == "dify" and document.metadata["doc_id"] not in doc_id: if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids:
doc_id.add(document.metadata["doc_id"]) doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content) docs.append(document.page_content)
unique_documents.append(document) unique_documents.append(document)
elif document.provider == "external": elif document.provider == "external":

View File

@ -1,6 +1,6 @@
from enum import Enum from enum import StrEnum
class RerankMode(str, Enum): class RerankMode(StrEnum):
RERANKING_MODEL = "reranking_model" RERANKING_MODEL = "reranking_model"
WEIGHTED_SCORE = "weighted_score" WEIGHTED_SCORE = "weighted_score"

View File

@ -37,11 +37,10 @@ class WeightRerankRunner(BaseRerankRunner):
:return: :return:
""" """
unique_documents = [] unique_documents = []
doc_id = set() doc_ids = set()
for document in documents: for document in documents:
doc_id = document.metadata.get("doc_id") if document.metadata["doc_id"] not in doc_ids:
if doc_id not in doc_id: doc_ids.add(document.metadata["doc_id"])
doc_id.add(doc_id)
unique_documents.append(document) unique_documents.append(document)
documents = unique_documents documents = unique_documents

View File

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pytz import timezone as pytz_timezone from pytz import timezone as pytz_timezone
@ -23,7 +23,7 @@ class CurrentTimeTool(BuiltinTool):
tz = tool_parameters.get("timezone", "UTC") tz = tool_parameters.get("timezone", "UTC")
fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z"
if tz == "UTC": if tz == "UTC":
return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") return self.create_text_message(f"{datetime.now(UTC).strftime(fm)}")
try: try:
tz = pytz_timezone(tz) tz = pytz_timezone(tz)

View File

@ -1,7 +1,7 @@
import json import json
from collections.abc import Generator, Iterable from collections.abc import Generator, Iterable
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timezone from datetime import UTC, datetime
from mimetypes import guess_type from mimetypes import guess_type
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
@ -64,7 +64,12 @@ class ToolEngine:
if parameters and len(parameters) == 1: if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters} tool_parameters = {parameters[0].name: tool_parameters}
else: else:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") try:
tool_parameters = json.loads(tool_parameters)
except Exception as e:
pass
if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
# invoke the tool # invoke the tool
try: try:
@ -195,10 +200,7 @@ class ToolEngine:
""" """
Invoke the tool with the given arguments. Invoke the tool with the given arguments.
""" """
if not tool.runtime: started_at = datetime.now(UTC)
raise ValueError("missing runtime in tool")
started_at = datetime.now(timezone.utc)
meta = ToolInvokeMeta( meta = ToolInvokeMeta(
time_cost=0.0, time_cost=0.0,
error=None, error=None,
@ -216,7 +218,7 @@ class ToolEngine:
meta.error = str(e) meta.error = str(e)
raise ToolEngineInvokeError(meta) raise ToolEngineInvokeError(meta)
finally: finally:
ended_at = datetime.now(timezone.utc) ended_at = datetime.now(UTC)
meta.time_cost = (ended_at - started_at).total_seconds() meta.time_cost = (ended_at - started_at).total_seconds()
yield meta yield meta

View File

@ -118,11 +118,11 @@ class FileSegment(Segment):
@property @property
def log(self) -> str: def log(self) -> str:
return str(self.value) return ""
@property @property
def text(self) -> str: def text(self) -> str:
return str(self.value) return ""
class ArrayAnySegment(ArraySegment): class ArrayAnySegment(ArraySegment):
@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
for item in self.value: for item in self.value:
items.append(item.markdown) items.append(item.markdown)
return "\n".join(items) return "\n".join(items)
@property
def log(self) -> str:
return ""
@property
def text(self) -> str:
return ""

View File

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class SegmentType(str, Enum): class SegmentType(StrEnum):
NONE = "none" NONE = "none"
NUMBER = "number" NUMBER = "number"
STRING = "string" STRING = "string"

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -8,7 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
class NodeRunMetadataKey(str, Enum): class NodeRunMetadataKey(StrEnum):
""" """
Node Run Metadata Key. Node Run Metadata Key.
""" """
@ -36,7 +36,7 @@ class NodeRunResult(BaseModel):
inputs: Optional[Mapping[str, Any]] = None # node inputs inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage llm_usage: Optional[LLMUsage] = None # llm usage

View File

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class SystemVariableKey(str, Enum): class SystemVariableKey(StrEnum):
""" """
System Variables. System Variables.
""" """

View File

@ -1,5 +1,5 @@
import uuid import uuid
from datetime import datetime, timezone from datetime import UTC, datetime
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
@ -63,7 +63,7 @@ class RouteNodeState(BaseModel):
raise Exception(f"Invalid route status {run_result.status}") raise Exception(f"Invalid route status {run_result.status}")
self.node_run_result = run_result self.node_run_result = run_result
self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) self.finished_at = datetime.now(UTC).replace(tzinfo=None)
class RuntimeRouteState(BaseModel): class RuntimeRouteState(BaseModel):
@ -81,7 +81,7 @@ class RuntimeRouteState(BaseModel):
:param node_id: node id :param node_id: node id
""" """
state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
self.node_state_mapping[state.id] = state self.node_state_mapping[state.id] = state
return state return state

Some files were not shown because too many files have changed in this diff Show More