mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-04-10 11:22:44 +08:00
Introduce Plugins (#13836)
Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld <lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
This commit is contained in:
parent
222df44d21
commit
403e2d58b9
@ -1,11 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd web && npm install
|
||||
npm add -g pnpm@9.12.2
|
||||
cd web && pnpm install
|
||||
pipx install poetry
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
|
||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
|
||||
|
||||
|
6
.github/workflows/api-tests.yml
vendored
6
.github/workflows/api-tests.yml
vendored
@ -50,15 +50,9 @@ jobs:
|
||||
- name: Run Unit tests
|
||||
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run ModelRuntime
|
||||
run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
|
||||
|
||||
- name: Run dify config tests
|
||||
run: poetry run -P api python dev/pytest/pytest_config_tests.py
|
||||
|
||||
- name: Run Tool
|
||||
run: poetry run -P api bash dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Run mypy
|
||||
run: |
|
||||
poetry run -C api python -m mypy --install-types --non-interactive .
|
||||
|
1
.github/workflows/db-migration-test.yml
vendored
1
.github/workflows/db-migration-test.yml
vendored
@ -4,6 +4,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/migrations/**
|
||||
- .github/workflows/db-migration-test.yml
|
||||
|
10
.github/workflows/style.yml
vendored
10
.github/workflows/style.yml
vendored
@ -72,17 +72,23 @@ jobs:
|
||||
with:
|
||||
files: web/**
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v4
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 20
|
||||
cache: yarn
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: yarn install --frozen-lockfile
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
6
.github/workflows/tool-test-sdks.yaml
vendored
6
.github/workflows/tool-test-sdks.yaml
vendored
@ -35,10 +35,10 @@ jobs:
|
||||
with:
|
||||
node-version: ${{ matrix.node-version }}
|
||||
cache: ''
|
||||
cache-dependency-path: 'yarn.lock'
|
||||
cache-dependency-path: 'pnpm-lock.yaml'
|
||||
|
||||
- name: Install Dependencies
|
||||
run: yarn install
|
||||
run: pnpm install
|
||||
|
||||
- name: Test
|
||||
run: yarn test
|
||||
run: pnpm test
|
||||
|
@ -39,11 +39,11 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: yarn install --frozen-lockfile
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run npm script
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: npm run auto-gen-i18n
|
||||
run: pnpm run auto-gen-i18n
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
|
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@ -37,13 +37,13 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 20
|
||||
cache: yarn
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/package.json
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: yarn install --frozen-lockfile
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: yarn test
|
||||
run: pnpm test
|
||||
|
7
.gitignore
vendored
7
.gitignore
vendored
@ -176,6 +176,7 @@ docker/volumes/pgvector/data/*
|
||||
docker/volumes/pgvecto_rs/data/*
|
||||
docker/volumes/couchbase/*
|
||||
docker/volumes/oceanbase/*
|
||||
docker/volumes/plugin_daemon/*
|
||||
!docker/volumes/oceanbase/init.d
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
@ -194,3 +195,9 @@ api/.vscode
|
||||
|
||||
.idea/
|
||||
.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
|
@ -1,7 +1,10 @@
|
||||
.env
|
||||
*.env.*
|
||||
|
||||
storage/generate_files/*
|
||||
storage/privkeys/*
|
||||
storage/tools/*
|
||||
storage/upload_files/*
|
||||
|
||||
# Logs
|
||||
logs
|
||||
@ -9,6 +12,8 @@ logs
|
||||
|
||||
# jetbrains
|
||||
.idea
|
||||
.mypy_cache
|
||||
.ruff_cache
|
||||
|
||||
# venv
|
||||
.venv
|
@ -409,7 +409,6 @@ MAX_VARIABLE_SIZE=204800
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
|
||||
# Celery beat configuration
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
|
||||
@ -422,6 +421,22 @@ POSITION_PROVIDER_PINS=
|
||||
POSITION_PROVIDER_INCLUDES=
|
||||
POSITION_PROVIDER_EXCLUDES=
|
||||
|
||||
# Plugin configuration
|
||||
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
|
||||
PLUGIN_DAEMON_URL=http://127.0.0.1:5002
|
||||
PLUGIN_REMOTE_INSTALL_PORT=5003
|
||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
|
||||
# Marketplace configuration
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
# Reset password token expiry minutes
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
|
@ -73,6 +73,10 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
# Download nltk data
|
||||
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
|
||||
|
||||
# Copy source code
|
||||
COPY . /app/api/
|
||||
|
||||
|
@ -25,6 +25,8 @@ from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.provider import Provider, ProviderModel
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@ -524,7 +526,7 @@ def add_qdrant_doc_id_index(field: str):
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
click.echo(click.style("Failed to create Qdrant client.", fg="red"))
|
||||
|
||||
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
|
||||
@ -593,7 +595,7 @@ def upgrade_db():
|
||||
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("Failed to execute database migration")
|
||||
finally:
|
||||
lock.release()
|
||||
@ -639,7 +641,7 @@ where sites.id is null limit 1000"""
|
||||
account = accounts[0]
|
||||
print("Fixing missing site for app {}".format(app.id))
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
failed_app_ids.append(app_id)
|
||||
click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
|
||||
logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
|
||||
@ -649,3 +651,68 @@ where sites.id is null limit 1000"""
|
||||
break
|
||||
|
||||
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
|
||||
|
||||
|
||||
@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
|
||||
def migrate_data_for_plugin():
|
||||
"""
|
||||
Migrate data for plugin.
|
||||
"""
|
||||
click.echo(click.style("Starting migrate data for plugin.", fg="white"))
|
||||
|
||||
PluginDataMigration.migrate()
|
||||
|
||||
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("extract-plugins", help="Extract plugins.")
|
||||
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
|
||||
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
|
||||
def extract_plugins(output_file: str, workers: int):
|
||||
"""
|
||||
Extract plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting extract plugins.", fg="white"))
|
||||
|
||||
PluginMigration.extract_plugins(output_file, workers)
|
||||
|
||||
click.echo(click.style("Extract plugins completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
|
||||
@click.option(
|
||||
"--output_file",
|
||||
prompt=True,
|
||||
help="The file to store the extracted unique identifiers.",
|
||||
default="unique_identifiers.json",
|
||||
)
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
def extract_unique_plugins(output_file: str, input_file: str):
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting extract unique plugins.", fg="white"))
|
||||
|
||||
PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
|
||||
|
||||
click.echo(click.style("Extract unique plugins completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("install-plugins", help="Install plugins.")
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
def install_plugins(input_file: str, output_file: str):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
click.echo(click.style("Starting install plugins.", fg="white"))
|
||||
|
||||
PluginMigration.install_plugins(input_file, output_file)
|
||||
|
||||
click.echo(click.style("Install plugins completed.", fg="green"))
|
||||
|
@ -134,6 +134,60 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
Plugin configs
|
||||
"""
|
||||
|
||||
PLUGIN_DAEMON_URL: HttpUrl = Field(
|
||||
description="Plugin API URL",
|
||||
default="http://localhost:5002",
|
||||
)
|
||||
|
||||
PLUGIN_DAEMON_KEY: str = Field(
|
||||
description="Plugin API key",
|
||||
default="plugin-api-key",
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||
description="Plugin Remote Install Host",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field(
|
||||
description="Plugin Remote Install Port",
|
||||
default=5003,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for plugin packages in bytes",
|
||||
default=15728640,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for plugin bundles in bytes",
|
||||
default=15728640 * 12,
|
||||
)
|
||||
|
||||
|
||||
class MarketplaceConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for marketplace
|
||||
"""
|
||||
|
||||
MARKETPLACE_ENABLED: bool = Field(
|
||||
description="Enable or disable marketplace",
|
||||
default=True,
|
||||
)
|
||||
|
||||
MARKETPLACE_API_URL: HttpUrl = Field(
|
||||
description="Marketplace API URL",
|
||||
default="https://marketplace.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@ -160,6 +214,10 @@ class EndpointConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
ENDPOINT_URL_TEMPLATE: str = Field(
|
||||
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
|
||||
)
|
||||
|
||||
|
||||
class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
@ -793,6 +851,8 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
FileAccessConfig,
|
||||
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.15.3",
|
||||
default="1.0.0",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
@ -1,9 +1,19 @@
|
||||
from contextvars import ContextVar
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
|
||||
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
|
||||
|
||||
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
|
||||
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
|
||||
|
@ -2,7 +2,7 @@ from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .app.app_import import AppImportApi, AppImportConfirmApi
|
||||
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
|
||||
from .explore.audio import ChatAudioApi, ChatTextApi
|
||||
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from .explore.conversation import (
|
||||
@ -40,6 +40,7 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
# Import App
|
||||
api.add_resource(AppImportApi, "/apps/imports")
|
||||
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
|
||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, ping, setup, version
|
||||
@ -166,4 +167,15 @@ api.add_resource(
|
||||
from .tag import tags
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace
|
||||
from .workspace import (
|
||||
account,
|
||||
agent_providers,
|
||||
endpoint,
|
||||
load_balancing_config,
|
||||
members,
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
workspace,
|
||||
)
|
||||
|
@ -2,6 +2,8 @@ from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource):
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app = App.query.filter(App.id == args["app_id"]).first()
|
||||
with Session(db.engine) as session:
|
||||
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
|
||||
if not app:
|
||||
raise NotFound(f"App '{args['app_id']}' is not found")
|
||||
|
||||
@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource):
|
||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
recommended_app = RecommendedApp(
|
||||
@ -110,17 +116,27 @@ class InsertExploreAppApi(Resource):
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
|
||||
with Session(db.engine) as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
|
||||
app = App.query.filter(App.id == recommended_app.app_id).first()
|
||||
with Session(db.engine) as session:
|
||||
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
installed_apps = InstalledApp.query.filter(
|
||||
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
|
||||
).all()
|
||||
with Session(db.engine) as session:
|
||||
installed_apps = session.execute(
|
||||
select(InstalledApp).filter(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
for installed_app in installed_apps:
|
||||
db.session.delete(installed_app)
|
||||
|
@ -3,6 +3,8 @@ from typing import Any
|
||||
import flask_restful # type: ignore
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -26,7 +28,16 @@ api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="it
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
|
||||
if resource_model == App:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if resource is None:
|
||||
flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
|
||||
|
@ -5,14 +5,16 @@ from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_import_fields
|
||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
|
||||
|
||||
@ -88,3 +90,20 @@ class AppImportConfirmApi(Resource):
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class AppImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_fields)
|
||||
def get(self, app_model: App):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
@ -2,6 +2,7 @@ from datetime import UTC, datetime
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
@ -50,33 +51,37 @@ class AppSite(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
site = Site.query.filter(Site.app_id == app_model.id).one_or_404()
|
||||
with Session(db.engine) as session:
|
||||
site = session.query(Site).filter(Site.app_id == app_model.id).first()
|
||||
|
||||
for attr_name in [
|
||||
"title",
|
||||
"icon_type",
|
||||
"icon",
|
||||
"icon_background",
|
||||
"description",
|
||||
"default_language",
|
||||
"chat_color_theme",
|
||||
"chat_color_theme_inverted",
|
||||
"customize_domain",
|
||||
"copyright",
|
||||
"privacy_policy",
|
||||
"custom_disclaimer",
|
||||
"customize_token_strategy",
|
||||
"prompt_public",
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
for attr_name in [
|
||||
"title",
|
||||
"icon_type",
|
||||
"icon",
|
||||
"icon_background",
|
||||
"description",
|
||||
"default_language",
|
||||
"chat_color_theme",
|
||||
"chat_color_theme_inverted",
|
||||
"customize_domain",
|
||||
"copyright",
|
||||
"privacy_policy",
|
||||
"custom_disclaimer",
|
||||
"customize_token_strategy",
|
||||
"prompt_public",
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
session.commit()
|
||||
|
||||
return site
|
||||
|
||||
|
@ -20,6 +20,7 @@ from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
@ -96,6 +97,9 @@ class DraftWorkflowApi(Resource):
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
@ -139,6 +143,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json", default="")
|
||||
@ -160,7 +167,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
@ -178,38 +185,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -228,7 +204,44 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_iteration(
|
||||
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
@ -246,6 +259,9 @@ class DraftWorkflowRunApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
@ -294,13 +310,20 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
|
||||
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
|
||||
)
|
||||
|
||||
return workflow_node_execution
|
||||
@ -339,6 +362,9 @@ class PublishedWorkflowApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
|
||||
|
||||
@ -376,12 +402,17 @@ class DefaultBlockConfigApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
filters = None
|
||||
if args.get("q"):
|
||||
if q:
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
except json.JSONDecodeError:
|
||||
@ -407,6 +438,9 @@ class ConvertToWorkflowApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
|
@ -3,6 +3,8 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
@ -43,7 +45,8 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = Account.query.filter_by(email=args["email"]).first()
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
token = None
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
@ -116,7 +119,8 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
account = Account.query.filter_by(email=reset_data.get("email")).first()
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
|
||||
if account:
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
@ -137,7 +141,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
pass
|
||||
except AccountRegisterError as are:
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
@ -5,6 +5,8 @@ from typing import Optional
|
||||
import requests
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restful import Resource # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Optional[Account] = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
account = Account.query.filter_by(email=user_info.email).first()
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
||||
|
||||
return account
|
||||
|
||||
|
@ -4,6 +4,8 @@ import json
|
||||
from flask import request
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@ -76,7 +78,10 @@ class DataSourceApi(Resource):
|
||||
def patch(self, binding_id, action):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
# enable binding
|
||||
@ -108,47 +113,53 @@ class DataSourceNotionListApi(Resource):
|
||||
def get(self):
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
exist_page_ids = []
|
||||
# import notion in the exist dataset
|
||||
if dataset_id:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if dataset.data_source_type != "notion_import":
|
||||
raise ValueError("Dataset is not notion type.")
|
||||
documents = Document.query.filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
if dataset_id:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if dataset.data_source_type != "notion_import":
|
||||
raise ValueError("Dataset is not notion type.")
|
||||
|
||||
documents = session.execute(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
)
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = session.scalars(
|
||||
select(DataSourceOauthBinding).filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
)
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = DataSourceOauthBinding.query.filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
|
||||
|
||||
class DataSourceNotionApi(Resource):
|
||||
@ -158,14 +169,17 @@ class DataSourceNotionApi(Resource):
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
).first()
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
|
@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, enterpris
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
@ -72,7 +73,9 @@ class DatasetListApi(Resource):
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == "high_quality":
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True
|
||||
|
@ -7,7 +7,6 @@ from flask import request
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy import asc, desc
|
||||
from transformers.hf_argparser import string_to_bool # type: ignore
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -40,6 +39,7 @@ from core.indexing_runner import IndexingRunner
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -150,8 +150,20 @@ class DatasetDocumentListApi(Resource):
|
||||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch = string_to_bool(request.args.get("fetch", default="false"))
|
||||
except (ArgumentTypeError, ValueError, Exception) as e:
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
if isinstance(fetch_val, bool):
|
||||
fetch = fetch_val
|
||||
else:
|
||||
if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
|
||||
fetch = True
|
||||
elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
|
||||
fetch = False
|
||||
else:
|
||||
raise ArgumentTypeError(
|
||||
f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
|
||||
f"(case insensitive)."
|
||||
)
|
||||
except (ArgumentTypeError, ValueError, Exception):
|
||||
fetch = False
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@ -429,6 +441,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except PluginDaemonClientSideError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except Exception as e:
|
||||
raise IndexingEstimateError(str(e))
|
||||
|
||||
@ -529,6 +543,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except PluginDaemonClientSideError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except Exception as e:
|
||||
raise IndexingEstimateError(str(e))
|
||||
|
||||
|
@ -2,8 +2,11 @@ import os
|
||||
|
||||
from flask import session
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
@ -42,7 +45,11 @@ class InitValidateAPI(Resource):
|
||||
def get_init_validate_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
return session.get("is_init_validated") or DifySetup.query.first()
|
||||
if session.get("is_init_validated"):
|
||||
return True
|
||||
|
||||
with Session(db.engine) as db_session:
|
||||
return db_session.execute(select(DifySetup)).scalar_one_or_none()
|
||||
|
||||
return True
|
||||
|
||||
|
@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
|
||||
from configs import dify_config
|
||||
from libs.helper import StrLen, email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup
|
||||
from models.model import DifySetup, db
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
from . import api
|
||||
@ -52,8 +52,9 @@ class SetupApi(Resource):
|
||||
|
||||
def get_setup_status():
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return DifySetup.query.first()
|
||||
return True
|
||||
return db.session.query(DifySetup).first()
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
api.add_resource(SetupApi, "/setup")
|
||||
|
@ -0,0 +1,56 @@
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission)
|
||||
.filter(
|
||||
TenantPluginPermission.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not permission:
|
||||
# no permission set, allow access for everyone
|
||||
return view(*args, **kwargs)
|
||||
|
||||
if install_required:
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
if debug_required:
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return interceptor
|
36
api/controllers/console/workspace/agent_providers.py
Normal file
36
api/controllers/console/workspace/agent_providers.py
Normal file
@ -0,0 +1,36 @@
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource # type: ignore
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
class AgentProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||
|
||||
|
||||
class AgentProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
user = current_user
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||
|
||||
|
||||
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
|
||||
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")
|
205
api/controllers/console/workspace/endpoint.py
Normal file
205
api/controllers/console/workspace/endpoint.py
Normal file
@ -0,0 +1,205 @@
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
class EndpointCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True)
|
||||
parser.add_argument("settings", type=dict, required=True)
|
||||
parser.add_argument("name", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
plugin_unique_identifier = args["plugin_unique_identifier"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EndpointListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class EndpointListForSinglePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
page = args["page"]
|
||||
page_size = args["page_size"]
|
||||
plugin_id = args["plugin_id"]
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class EndpointDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EndpointUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
parser.add_argument("settings", type=dict, required=True)
|
||||
parser.add_argument("name", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EndpointEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EndpointDisableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
|
||||
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
|
||||
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
|
||||
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
|
||||
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
|
||||
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
|
||||
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")
|
@ -112,10 +112,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
# Load Balancing Config
|
||||
api.add_resource(
|
||||
LoadBalancingCredentialsValidateApi,
|
||||
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
LoadBalancingConfigCredentialsValidateApi,
|
||||
"/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
|
||||
)
|
||||
|
@ -79,7 +79,7 @@ class ModelProviderValidateApi(Resource):
|
||||
response = {"result": "success" if result else "error"}
|
||||
|
||||
if not result:
|
||||
response["error"] = error
|
||||
response["error"] = error or "Unknown error"
|
||||
|
||||
return response
|
||||
|
||||
@ -125,9 +125,10 @@ class ModelProviderIconApi(Resource):
|
||||
Get model provider icon
|
||||
"""
|
||||
|
||||
def get(self, provider: str, icon_type: str, lang: str):
|
||||
def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
icon, mimetype = model_provider_service.get_model_provider_icon(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
icon_type=icon_type,
|
||||
lang=lang,
|
||||
@ -183,53 +184,17 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
return data
|
||||
|
||||
|
||||
class ModelProviderFreeQuotaSubmitApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=False, nullable=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
result = model_provider_service.free_quota_qualification_verify(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
|
||||
api.add_resource(
|
||||
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
|
||||
)
|
||||
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
|
||||
|
||||
api.add_resource(
|
||||
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
|
||||
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
|
||||
)
|
||||
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
|
||||
api.add_resource(
|
||||
ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderFreeQuotaQualificationVerifyApi,
|
||||
"/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
|
||||
ModelProviderIconApi,
|
||||
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
|
||||
)
|
||||
|
@ -325,7 +325,7 @@ class ModelProviderModelValidateApi(Resource):
|
||||
response = {"result": "success" if result else "error"}
|
||||
|
||||
if not result:
|
||||
response["error"] = error
|
||||
response["error"] = error or ""
|
||||
|
||||
return response
|
||||
|
||||
@ -362,26 +362,26 @@ class ModelProviderAvailableModelApi(Resource):
|
||||
return jsonable_encoder({"data": models})
|
||||
|
||||
|
||||
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
|
||||
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
|
||||
api.add_resource(
|
||||
ModelProviderModelEnableApi,
|
||||
"/workspaces/current/model-providers/<string:provider>/models/enable",
|
||||
"/workspaces/current/model-providers/<path:provider>/models/enable",
|
||||
endpoint="model-provider-model-enable",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelDisableApi,
|
||||
"/workspaces/current/model-providers/<string:provider>/models/disable",
|
||||
"/workspaces/current/model-providers/<path:provider>/models/disable",
|
||||
endpoint="model-provider-model-disable",
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
|
||||
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
|
||||
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
|
||||
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
|
||||
)
|
||||
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
|
||||
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")
|
||||
|
475
api/controllers/console/workspace/plugin.py
Normal file
475
api/controllers/console/workspace/plugin.py
Normal file
@ -0,0 +1,475 @@
|
||||
import io
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from libs.login import login_required
|
||||
from models.account import TenantPluginPermission
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {
|
||||
"key": PluginService.get_debugging_key(tenant_id),
|
||||
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
|
||||
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
|
||||
}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
try:
|
||||
plugins = PluginService.list(tenant_id)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
class PluginIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
req.add_argument("filename", type=str, required=True, location="args")
|
||||
args = req.parse_args()
|
||||
|
||||
try:
|
||||
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
||||
class PluginUploadFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
file = request.files["pkg"]
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
try:
|
||||
response = PluginService.upload_pkg(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginUploadFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
parser.add_argument("version", type=str, required=True, location="json")
|
||||
parser.add_argument("package", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginUploadFromBundleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
file = request.files["bundle"]
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
try:
|
||||
response = PluginService.upload_bundle(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
if not isinstance(plugin_unique_identifier, str):
|
||||
raise ValueError("Invalid plugin unique identifier")
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
parser.add_argument("version", type=str, required=True, location="json")
|
||||
parser.add_argument("package", type=str, required=True, location="json")
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_github(
|
||||
tenant_id,
|
||||
args["plugin_unique_identifier"],
|
||||
args["repo"],
|
||||
args["version"],
|
||||
args["package"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
if not isinstance(plugin_unique_identifier, str):
|
||||
raise ValueError("Invalid plugin unique identifier")
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginFetchManifestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"manifest": PluginService.fetch_plugin_manifest(
|
||||
tenant_id, args["plugin_unique_identifier"]
|
||||
).model_dump()
|
||||
}
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginFetchInstallTasksApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginFetchInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginDeleteInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginDeleteInstallTaskItemApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self, task_id: str, identifier: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
PluginService.upgrade_plugin_with_marketplace(
|
||||
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginUpgradeFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
parser.add_argument("version", type=str, required=True, location="json")
|
||||
parser.add_argument("package", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
PluginService.upgrade_plugin_with_github(
|
||||
tenant_id,
|
||||
args["original_plugin_unique_identifier"],
|
||||
args["new_plugin_unique_identifier"],
|
||||
args["repo"],
|
||||
args["version"],
|
||||
args["package"],
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginUninstallApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def post(self):
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginChangePermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("install_permission", type=str, required=True, location="json")
|
||||
req.add_argument("debug_permission", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
|
||||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||
|
||||
|
||||
class PluginFetchPermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
if not permission:
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
}
|
||||
)
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"install_permission": permission.install_permission,
|
||||
"debug_permission": permission.debug_permission,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
|
||||
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
|
||||
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
|
||||
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
|
||||
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
|
||||
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
|
||||
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
|
||||
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
|
||||
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
|
||||
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
|
||||
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
|
||||
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
|
||||
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
|
||||
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
|
||||
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
|
||||
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
|
||||
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
|
||||
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
|
||||
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
|
||||
|
||||
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
|
||||
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
|
@ -25,8 +25,10 @@ class ToolProviderListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument(
|
||||
@ -47,28 +49,43 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
provider,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
|
||||
|
||||
|
||||
class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if not current_user.is_admin_or_owner:
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||
user_id,
|
||||
@ -82,11 +99,13 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
if not current_user.is_admin_or_owner:
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
@ -131,11 +150,13 @@ class ToolApiProviderAddApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
@ -168,6 +189,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
|
||||
@ -175,8 +201,8 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
user_id,
|
||||
tenant_id,
|
||||
args["url"],
|
||||
)
|
||||
|
||||
@ -186,8 +212,10 @@ class ToolApiProviderListToolsApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@ -209,11 +237,13 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
@ -248,11 +278,13 @@ class ToolApiProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@ -272,8 +304,10 @@ class ToolApiProviderGetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@ -293,7 +327,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
|
||||
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@ -344,11 +382,13 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
@ -381,11 +421,13 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
@ -421,11 +463,13 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_admin_or_owner:
|
||||
user = current_user
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
@ -444,8 +488,10 @@ class ToolWorkflowProviderGetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
@ -476,8 +522,10 @@ class ToolWorkflowProviderListToolApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
|
||||
@ -498,8 +546,10 @@ class ToolBuiltinListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
@ -517,8 +567,10 @@ class ToolApiListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
@ -536,8 +588,10 @@ class ToolWorkflowListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user_id = current_user.id
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
@ -563,16 +617,18 @@ class ToolLabelsApi(Resource):
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
|
||||
# builtin tool provider
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
|
||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
|
||||
)
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
|
||||
# api tool provider
|
||||
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
|
||||
|
@ -7,6 +7,7 @@ from flask_login import current_user # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
from services.operation_service import OperationService
|
||||
@ -134,9 +135,13 @@ def setup_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# check setup
|
||||
if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first():
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
and os.environ.get("INIT_PASSWORD")
|
||||
and not db.session.query(DifySetup).first()
|
||||
):
|
||||
raise NotInitValidateError()
|
||||
elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first():
|
||||
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
@ -6,4 +6,4 @@ bp = Blueprint("files", __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import image_preview, tool_files
|
||||
from . import image_preview, tool_files, upload
|
||||
|
69
api/controllers/files/upload.py
Normal file
69
api/controllers/files/upload.py
Normal file
@ -0,0 +1,69 @@
|
||||
from flask import request
|
||||
from flask_restful import Resource, marshal_with # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.files import api
|
||||
from controllers.files.error import UnsupportedFileTypeError
|
||||
from controllers.inner_api.plugin.wraps import get_user
|
||||
from controllers.service_api.app.error import FileTooLargeError
|
||||
from core.file.helpers import verify_plugin_file_signature
|
||||
from fields.file_fields import file_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class PluginUploadFileApi(Resource):
|
||||
@setup_required
|
||||
@marshal_with(file_fields)
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
sign = request.args.get("sign")
|
||||
tenant_id = request.args.get("tenant_id")
|
||||
if not tenant_id:
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
user_id = request.args.get("user_id")
|
||||
user = get_user(tenant_id, user_id)
|
||||
|
||||
filename = file.filename
|
||||
mimetype = file.mimetype
|
||||
|
||||
if not filename or not mimetype:
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
if not timestamp or not nonce or not sign:
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
if not verify_plugin_file_signature(
|
||||
filename=filename,
|
||||
mimetype=mimetype,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
filename=filename,
|
||||
content=file.read(),
|
||||
mimetype=mimetype,
|
||||
user=user,
|
||||
source=None,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
|
@ -5,4 +5,5 @@ from libs.external_api import ExternalApi
|
||||
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from .plugin import plugin
|
||||
from .workspace import workspace
|
||||
|
293
api/controllers/inner_api/plugin/plugin.py
Normal file
293
api/controllers/inner_api/plugin/plugin.py
Normal file
@ -0,0 +1,293 @@
|
||||
from flask_restful import Resource # type: ignore
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
from core.file.helpers import get_signed_file_url_for_plugin
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
|
||||
from core.plugin.backwards_invocation.encrypt import PluginEncrypter
|
||||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeApp,
|
||||
RequestInvokeEncrypt,
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeModeration,
|
||||
RequestInvokeParameterExtractorNode,
|
||||
RequestInvokeQuestionClassifierNode,
|
||||
RequestInvokeRerank,
|
||||
RequestInvokeSpeech2Text,
|
||||
RequestInvokeSummary,
|
||||
RequestInvokeTextEmbedding,
|
||||
RequestInvokeTool,
|
||||
RequestInvokeTTS,
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from libs.helper import compact_generate_response
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
class PluginInvokeLLMApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginModelBackwardsInvocation.invoke_text_embedding(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeRerankApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeRerank)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginModelBackwardsInvocation.invoke_rerank(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeTTSApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTTS)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_tts(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeSpeech2TextApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginModelBackwardsInvocation.invoke_speech2text(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeModerationApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeModeration)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginModelBackwardsInvocation.invoke_moderation(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeToolApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||
def generator():
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
PluginToolBackwardsInvocation.invoke_tool(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
tool_type=ToolProviderType.value_of(payload.tool_type),
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
tool_parameters=payload.tool_parameters,
|
||||
),
|
||||
)
|
||||
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
parameters=payload.parameters,
|
||||
model_config=payload.model,
|
||||
instruction=payload.instruction,
|
||||
query=payload.query,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
BaseBackwardsInvocationResponse(
|
||||
data=PluginNodeBackwardsInvocation.invoke_question_classifier(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
query=payload.query,
|
||||
model_config=payload.model,
|
||||
classes=payload.classes,
|
||||
instruction=payload.instruction,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||
|
||||
|
||||
class PluginInvokeAppApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
|
||||
response = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id=payload.app_id,
|
||||
user_id=user_model.id,
|
||||
tenant_id=tenant_model.id,
|
||||
conversation_id=payload.conversation_id,
|
||||
query=payload.query,
|
||||
stream=payload.response_mode == "streaming",
|
||||
inputs=payload.inputs,
|
||||
files=payload.files,
|
||||
)
|
||||
|
||||
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
|
||||
"""
|
||||
encrypt or decrypt data
|
||||
"""
|
||||
try:
|
||||
return BaseBackwardsInvocationResponse(
|
||||
data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
|
||||
).model_dump()
|
||||
except Exception as e:
|
||||
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||
|
||||
|
||||
class PluginInvokeSummaryApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestInvokeSummary)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
|
||||
try:
|
||||
return BaseBackwardsInvocationResponse(
|
||||
data={
|
||||
"summary": PluginModelBackwardsInvocation.invoke_summary(
|
||||
user_id=user_model.id,
|
||||
tenant=tenant_model,
|
||||
payload=payload,
|
||||
)
|
||||
}
|
||||
).model_dump()
|
||||
except Exception as e:
|
||||
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||
|
||||
|
||||
class PluginUploadFileRequestApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@get_user_tenant
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
|
||||
# generate signed url
|
||||
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
||||
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
||||
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
||||
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
|
||||
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
|
||||
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
|
||||
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
|
||||
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
|
||||
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
|
||||
api.add_resource(PluginInvokeAppApi, "/invoke/app")
|
||||
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
|
||||
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
|
||||
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
|
116
api/controllers/inner_api/plugin/wraps.py
Normal file
116
api/controllers/inner_api/plugin/wraps.py
Normal file
@ -0,0 +1,116 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from flask_restful import reqparse # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
if user_id == "DEFAULT-USER":
|
||||
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=True if user_id == "DEFAULT-USER" else False,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
else:
|
||||
user_model = AccountService.load_user(user_id)
|
||||
if not user_model:
|
||||
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
if not user_model:
|
||||
raise ValueError("user not found")
|
||||
except Exception:
|
||||
raise ValueError("user not found")
|
||||
|
||||
return user_model
|
||||
|
||||
|
||||
def get_user_tenant(view: Optional[Callable] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
# fetch json body
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
||||
parser.add_argument("user_id", type=str, required=True, location="json")
|
||||
|
||||
kwargs = parser.parse_args()
|
||||
|
||||
user_id = kwargs.get("user_id")
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
if not tenant_id:
|
||||
raise ValueError("tenant_id is required")
|
||||
|
||||
if not user_id:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
del kwargs["tenant_id"]
|
||||
del kwargs["user_id"]
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.filter(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func):
|
||||
def decorated_view(*args, **kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
except Exception:
|
||||
raise ValueError("invalid json")
|
||||
|
||||
try:
|
||||
payload = payload_type(**data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid payload: {str(e)}")
|
||||
|
||||
kwargs["payload"] = payload
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse # type: ignore
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.wraps import inner_api_only
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from events.tenant_event import tenant_was_created
|
||||
from models.account import Account
|
||||
from services.account_service import TenantService
|
||||
@ -12,7 +12,7 @@ from services.account_service import TenantService
|
||||
|
||||
class EnterpriseWorkspace(Resource):
|
||||
@setup_required
|
||||
@inner_api_only
|
||||
@enterprise_inner_api_only
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
@ -33,7 +33,7 @@ class EnterpriseWorkspace(Resource):
|
||||
|
||||
class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
@setup_required
|
||||
@inner_api_only
|
||||
@enterprise_inner_api_only
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
|
@ -10,7 +10,7 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def inner_api_only(view):
|
||||
def enterprise_inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
@ -18,7 +18,7 @@ def inner_api_only(view):
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
|
||||
abort(401)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
@ -26,7 +26,7 @@ def inner_api_only(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def inner_api_user_auth(view):
|
||||
def enterprise_inner_api_user_auth(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
@ -60,3 +60,19 @@ def inner_api_user_auth(view):
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
@ -32,19 +31,16 @@ from core.model_runtime.entities import (
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -62,11 +58,9 @@ class BaseAgentRunner(AppRunner):
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
model_instance: ModelInstance,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
@ -79,8 +73,6 @@ class BaseAgentRunner(AppRunner):
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
self.model_instance = model_instance
|
||||
|
||||
# init callback
|
||||
@ -141,11 +133,10 @@ class BaseAgentRunner(AppRunner):
|
||||
agent_tool=tool,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
assert tool_entity.entity.description
|
||||
message_tool = PromptMessageTool(
|
||||
name=tool.tool_name,
|
||||
description=tool_entity.description.llm if tool_entity.description else "",
|
||||
description=tool_entity.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@ -153,7 +144,7 @@ class BaseAgentRunner(AppRunner):
|
||||
},
|
||||
)
|
||||
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
@ -186,9 +177,11 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
convert dataset retriever tool to prompt message tool
|
||||
"""
|
||||
assert tool.entity.description
|
||||
|
||||
prompt_tool = PromptMessageTool(
|
||||
name=tool.identity.name if tool.identity else "unknown",
|
||||
description=tool.description.llm if tool.description else "",
|
||||
name=tool.entity.identity.name,
|
||||
description=tool.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@ -234,8 +227,7 @@ class BaseAgentRunner(AppRunner):
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
if dataset_tool.identity is not None:
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||
|
||||
return tool_instances, prompt_messages_tools
|
||||
|
||||
@ -320,24 +312,24 @@ class BaseAgentRunner(AppRunner):
|
||||
def save_agent_thought(
|
||||
self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
tool_name: str | None,
|
||||
tool_input: Union[str, dict, None],
|
||||
thought: str | None,
|
||||
observation: Union[str, dict, None],
|
||||
tool_invoke_meta: Union[str, dict, None],
|
||||
answer: str,
|
||||
answer: str | None,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage | None = None,
|
||||
):
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
queried_thought = (
|
||||
updated_agent_thought = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
)
|
||||
if not queried_thought:
|
||||
raise ValueError(f"Agent thought {agent_thought.id} not found")
|
||||
agent_thought = queried_thought
|
||||
if not updated_agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
agent_thought = updated_agent_thought
|
||||
|
||||
if thought:
|
||||
agent_thought.thought = thought
|
||||
@ -349,39 +341,39 @@ class BaseAgentRunner(AppRunner):
|
||||
if isinstance(tool_input, dict):
|
||||
try:
|
||||
tool_input = json.dumps(tool_input, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tool_input = json.dumps(tool_input)
|
||||
|
||||
agent_thought.tool_input = tool_input
|
||||
updated_agent_thought.tool_input = tool_input
|
||||
|
||||
if observation:
|
||||
if isinstance(observation, dict):
|
||||
try:
|
||||
observation = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
observation = json.dumps(observation)
|
||||
|
||||
agent_thought.observation = observation
|
||||
updated_agent_thought.observation = observation
|
||||
|
||||
if answer:
|
||||
agent_thought.answer = answer
|
||||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
agent_thought.message_files = json.dumps(messages_ids)
|
||||
updated_agent_thought.message_files = json.dumps(messages_ids)
|
||||
|
||||
if llm_usage:
|
||||
agent_thought.message_token = llm_usage.prompt_tokens
|
||||
agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
agent_thought.answer_token = llm_usage.completion_tokens
|
||||
agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
agent_thought.tokens = llm_usage.total_tokens
|
||||
agent_thought.total_price = llm_usage.total_price
|
||||
updated_agent_thought.message_token = llm_usage.prompt_tokens
|
||||
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
updated_agent_thought.answer_token = llm_usage.completion_tokens
|
||||
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
updated_agent_thought.tokens = llm_usage.total_tokens
|
||||
updated_agent_thought.total_price = llm_usage.total_price
|
||||
|
||||
# check if tool labels is not empty
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(";") if agent_thought.tool else []
|
||||
labels = updated_agent_thought.tool_labels or {}
|
||||
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
@ -392,42 +384,20 @@ class BaseAgentRunner(AppRunner):
|
||||
else:
|
||||
labels[tool] = {"en_US": tool, "zh_Hans": tool}
|
||||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
updated_agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
if tool_invoke_meta is not None:
|
||||
if isinstance(tool_invoke_meta, dict):
|
||||
try:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta)
|
||||
|
||||
agent_thought.tool_meta_str = tool_invoke_meta
|
||||
updated_agent_thought.tool_meta_str = tool_invoke_meta
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
queried_variables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not queried_variables:
|
||||
return
|
||||
|
||||
db_variables = queried_variables
|
||||
|
||||
db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize agent history
|
||||
@ -464,11 +434,11 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tool_inputs = {tool: {} for tool in tools}
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
tool_responses = dict.fromkeys(tools, agent_thought.observation)
|
||||
|
||||
for tool in tools:
|
||||
@ -515,7 +485,11 @@ class BaseAgentRunner(AppRunner):
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
if not files:
|
||||
return UserPromptMessage(content=message.query)
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
if message.app_model_config:
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
else:
|
||||
file_extra_config = None
|
||||
|
||||
if not file_extra_config:
|
||||
return UserPromptMessage(content=message.query)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
@ -18,8 +18,8 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
@ -27,11 +27,11 @@ from models.model import Message
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage] | None = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] | None = None
|
||||
_instruction: str = "" # FIXME this must be str for now
|
||||
_query: str | None = None
|
||||
_prompt_messages_tools: list[PromptMessageTool] = []
|
||||
_historic_prompt_messages: list[PromptMessage]
|
||||
_agent_scratchpad: list[AgentScratchpadUnit]
|
||||
_instruction: str
|
||||
_query: str
|
||||
_prompt_messages_tools: Sequence[PromptMessageTool]
|
||||
|
||||
def run(
|
||||
self,
|
||||
@ -42,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
self._init_react_state(query)
|
||||
@ -54,17 +55,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config.agent
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
|
||||
instruction = app_config.prompt_template.simple_prompt_template or ""
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
@ -116,14 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
if not isinstance(chunks, Generator):
|
||||
raise ValueError("Expected streaming response from LLM")
|
||||
|
||||
# check llm result
|
||||
if not chunks:
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
@ -143,25 +139,25 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
if scratchpad.agent_response is not None:
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
if scratchpad.agent_response is not None:
|
||||
scratchpad.agent_response += chunk
|
||||
if scratchpad.thought is not None:
|
||||
scratchpad.thought += chunk
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += chunk
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
if scratchpad.thought is not None:
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
if self._agent_scratchpad is not None:
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if "usage" in usage_dict:
|
||||
@ -256,8 +252,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
)
|
||||
if self.variables_pool is not None and self.db_variables_pool is not None:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@ -275,7 +269,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: dict[str, Tool],
|
||||
tool_instances: Mapping[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
@ -315,11 +309,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
)
|
||||
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as is not None and self.variables_pool:
|
||||
# FIXME the save_as type is confusing, it should be a string or not
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
|
||||
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
@ -342,7 +332,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return instruction
|
||||
@ -379,7 +369,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: Optional[list[PromptMessage]] = None
|
||||
self, current_session_messages: list[PromptMessage] | None = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
@ -391,8 +381,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
if not isinstance(message.content, str | None):
|
||||
raise NotImplementedError("expected str type")
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
@ -411,9 +400,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
except:
|
||||
pass
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if not current_scratchpad:
|
||||
continue
|
||||
if isinstance(message.content, str):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
|
@ -19,8 +19,8 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
if not self.app_config.agent:
|
||||
raise ValueError("Agent configuration is not set")
|
||||
assert self.app_config.agent
|
||||
assert self.app_config.agent.prompt
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
if not prompt_entity:
|
||||
@ -83,8 +83,10 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
|
@ -1,18 +1,21 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
class AgentToolEntity(BaseModel):
|
||||
"""
|
||||
Agent Tool Entity.
|
||||
"""
|
||||
|
||||
provider_type: Literal["builtin", "api", "workflow"]
|
||||
provider_type: ToolProviderType
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
plugin_unique_identifier: str | None = None
|
||||
|
||||
|
||||
class AgentPromptEntity(BaseModel):
|
||||
@ -66,7 +69,7 @@ class AgentEntity(BaseModel):
|
||||
Agent Entity.
|
||||
"""
|
||||
|
||||
class Strategy(Enum):
|
||||
class Strategy(StrEnum):
|
||||
"""
|
||||
Agent Strategy.
|
||||
"""
|
||||
@ -78,5 +81,13 @@ class AgentEntity(BaseModel):
|
||||
model: str
|
||||
strategy: Strategy
|
||||
prompt: Optional[AgentPromptEntity] = None
|
||||
tools: list[AgentToolEntity] | None = None
|
||||
tools: Optional[list[AgentToolEntity]] = None
|
||||
max_iteration: int = 5
|
||||
|
||||
|
||||
class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
Agent Invoke Message.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
@ -46,18 +46,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
@ -107,7 +109,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if self.stream_tool_call and isinstance(chunks, Generator):
|
||||
if isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
@ -124,7 +126,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
@ -140,7 +142,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
current_llm_usage = chunk.delta.usage
|
||||
|
||||
yield chunk
|
||||
elif not self.stream_tool_call and isinstance(chunks, LLMResult):
|
||||
else:
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
@ -151,7 +153,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
@ -183,8 +185,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
usage=result.usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"invalid chunks type: {type(chunks)}")
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if tool_calls:
|
||||
@ -243,15 +243,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=self.message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
if self.variables_pool:
|
||||
self.variables_pool.set_file(
|
||||
tool_name=tool_call_name, value=message_file_id, name=save_as
|
||||
)
|
||||
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
@ -303,8 +300,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
if self.variables_pool and self.db_variables_pool:
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@ -335,9 +330,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(
|
||||
self, llm_result_chunk: LLMResultChunk
|
||||
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
@ -360,7 +353,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
@ -383,9 +376,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(
|
||||
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
|
||||
) -> list[PromptMessage]:
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
|
89
api/core/agent/plugin_entities.py
Normal file
89
api/core/agent/plugin_entities.py
Normal file
@ -0,0 +1,89 @@
|
||||
import enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolProviderIdentity,
|
||||
)
|
||||
|
||||
|
||||
class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
||||
"""
|
||||
Inherits from ToolProviderIdentity, without any additional fields.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentStrategyParameter(PluginParameter):
|
||||
class AgentStrategyParameterType(enum.StrEnum):
|
||||
"""
|
||||
Keep all the types from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class AgentStrategyProviderEntity(BaseModel):
|
||||
identity: AgentStrategyProviderIdentity
|
||||
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
|
||||
|
||||
|
||||
class AgentStrategyIdentity(ToolIdentity):
|
||||
"""
|
||||
Inherits from ToolIdentity, without any additional fields.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentStrategyEntity(BaseModel):
|
||||
identity: AgentStrategyIdentity
|
||||
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
|
||||
strategies: list[AgentStrategyEntity] = Field(default_factory=list)
|
42
api/core/agent/strategy/base.py
Normal file
42
api/core/agent/strategy/base.py
Normal file
@ -0,0 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
|
||||
|
||||
class BaseAgentStrategy(ABC):
|
||||
"""
|
||||
Agent Strategy
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent strategy.
|
||||
"""
|
||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
"""
|
||||
Get the parameters for the agent strategy.
|
||||
"""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
pass
|
59
api/core/agent/strategy/plugin.py
Normal file
59
api/core/agent/strategy/plugin.py
Normal file
@ -0,0 +1,59 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
from core.plugin.manager.agent import PluginAgentManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
||||
class PluginAgentStrategy(BaseAgentStrategy):
|
||||
"""
|
||||
Agent Strategy
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
declaration: AgentStrategyEntity
|
||||
|
||||
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
|
||||
self.tenant_id = tenant_id
|
||||
self.declaration = declaration
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Initialize the parameters for the agent strategy.
|
||||
"""
|
||||
for parameter in self.declaration.parameters:
|
||||
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||
return params
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent strategy.
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
|
||||
initialized_params = self.initialize_parameters(params)
|
||||
params = convert_parameters_to_plugin_format(initialized_params)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
agent_provider=self.declaration.identity.provider,
|
||||
agent_strategy=self.declaration.identity.name,
|
||||
agent_params=params,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
@ -4,7 +4,8 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
@ -63,14 +64,14 @@ class ModelConfigConverter:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||
|
||||
# get model mode
|
||||
model_mode = model_config.mode
|
||||
if not model_mode:
|
||||
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
|
||||
|
||||
model_mode = mode_enum.value
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||
model_mode = LLMMode.CHAT.value
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
@ -2,8 +2,9 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
@ -53,9 +54,18 @@ class ModelConfigManager:
|
||||
raise ValueError("model must be of object type")
|
||||
|
||||
# model.provider
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
if "provider" not in config["model"]:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
if "/" not in config["model"]["provider"]:
|
||||
config["model"]["provider"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
|
||||
)
|
||||
|
||||
if config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
# model.name
|
||||
|
@ -37,17 +37,6 @@ logger = logging.getLogger(__name__)
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
_dialogue_count: int
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@ -65,20 +54,31 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: Mapping,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
streaming: Literal[True],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: Mapping,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -154,6 +154,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
@ -165,8 +167,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
def single_iteration_generate(
|
||||
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -203,6 +211,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
@ -222,7 +232,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -56,7 +56,7 @@ def _process_future(
|
||||
|
||||
|
||||
class AppGeneratorTTSPublisher:
|
||||
def __init__(self, tenant_id: str, voice: str):
|
||||
def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.tenant_id = tenant_id
|
||||
self.msg_text = ""
|
||||
@ -67,7 +67,7 @@ class AppGeneratorTTSPublisher:
|
||||
self.model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.TTS
|
||||
)
|
||||
self.voices = self.model_instance.get_tts_voices()
|
||||
self.voices = self.model_instance.get_tts_voices(language=language)
|
||||
values = [voice.get("value") for voice in self.voices]
|
||||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
|
@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
@ -58,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[str, Any, None]:
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -84,12 +83,12 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[str, Any, None]:
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -123,4 +122,4 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
@ -17,6 +17,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
@ -219,7 +220,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
and features_dict["text_to_speech"].get("enabled")
|
||||
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
|
||||
):
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
|
||||
tts_publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
|
||||
)
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
@ -247,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception(f"Failed to listen audio message, task_id: {task_id}")
|
||||
break
|
||||
if tts_publisher:
|
||||
@ -640,6 +643,10 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
session.commit()
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
@ -29,17 +30,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@ -51,6 +41,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
@ -60,7 +61,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -70,7 +71,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -180,6 +181,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": contextvars.copy_context(),
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
@ -204,6 +206,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
context: contextvars.Context,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
@ -218,6 +221,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param message_id: message ID
|
||||
:return:
|
||||
"""
|
||||
for var, val in context.items():
|
||||
var.set(val)
|
||||
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# get conversation and message
|
||||
|
@ -8,18 +8,16 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationError
|
||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -64,8 +62,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query,
|
||||
)
|
||||
|
||||
@ -86,8 +84,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query,
|
||||
memory=memory,
|
||||
)
|
||||
@ -99,8 +97,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
inputs=dict(inputs),
|
||||
query=query or "",
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationError as e:
|
||||
@ -156,9 +154,9 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query or "",
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
@ -173,16 +171,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
return
|
||||
|
||||
agent_entity = app_config.agent
|
||||
if not agent_entity:
|
||||
raise ValueError("Agent entity not found")
|
||||
|
||||
# load tool variables
|
||||
tool_conversation_variables = self._load_tool_variables(
|
||||
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
|
||||
)
|
||||
|
||||
# convert db variables to tool variables
|
||||
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
|
||||
assert agent_entity is not None
|
||||
|
||||
# init model instance
|
||||
model_instance = ModelInstance(
|
||||
@ -193,9 +182,9 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
inputs=dict(inputs),
|
||||
files=list(files),
|
||||
query=query or "",
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
@ -243,8 +232,6 @@ class AgentChatAppRunner(AppRunner):
|
||||
user_id=application_generate_entity.user_id,
|
||||
memory=memory,
|
||||
prompt_messages=prompt_message,
|
||||
variables_pool=tool_variables,
|
||||
db_variables=tool_conversation_variables,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
@ -261,73 +248,3 @@ class AgentChatAppRunner(AppRunner):
|
||||
stream=application_generate_entity.stream,
|
||||
agent=True,
|
||||
)
|
||||
|
||||
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
|
||||
"""
|
||||
load tool variables from database
|
||||
"""
|
||||
tool_variables: ToolConversationVariables | None = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == conversation_id,
|
||||
ToolConversationVariables.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if tool_variables:
|
||||
# save tool variables to session, so that we can update it later
|
||||
db.session.add(tool_variables)
|
||||
else:
|
||||
# create new tool variables
|
||||
tool_variables = ToolConversationVariables(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
variables_str="[]",
|
||||
)
|
||||
db.session.add(tool_variables)
|
||||
db.session.commit()
|
||||
|
||||
return tool_variables
|
||||
|
||||
def _convert_db_variables_to_tool_variables(
|
||||
self, db_variables: ToolConversationVariables
|
||||
) -> ToolRuntimeVariablePool:
|
||||
"""
|
||||
convert db variables to tool variables
|
||||
"""
|
||||
return ToolRuntimeVariablePool(
|
||||
**{
|
||||
"conversation_id": db_variables.conversation_id,
|
||||
"user_id": db_variables.user_id,
|
||||
"tenant_id": db_variables.tenant_id,
|
||||
"pool": db_variables.variables,
|
||||
}
|
||||
)
|
||||
|
||||
def _get_usage_of_all_agent_thoughts(
|
||||
self, model_config: ModelConfigWithCredentialsEntity, message: Message
|
||||
) -> LLMUsage:
|
||||
"""
|
||||
Get usage of all agent thoughts
|
||||
:param model_config: model config
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
agent_thoughts = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
|
||||
)
|
||||
|
||||
all_message_tokens = 0
|
||||
all_answer_tokens = 0
|
||||
for agent_thought in agent_thoughts:
|
||||
all_message_tokens += agent_thought.message_tokens
|
||||
all_answer_tokens += agent_thought.answer_tokens
|
||||
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
return model_type_instance._calc_response_usage(
|
||||
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
|
||||
)
|
||||
|
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -51,10 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response( # type: ignore[override]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -80,13 +79,12 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response( # type: ignore[override]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -118,4 +116,4 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
@ -14,21 +14,15 @@ class AppGenerateResponseConverter(ABC):
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
cls,
|
||||
response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
|
||||
invoke_from: InvokeFrom,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
|
||||
def _generate_full_response() -> Generator[str, Any, None]:
|
||||
for chunk in cls.convert_stream_full_response(response):
|
||||
if chunk == "ping":
|
||||
yield f"event: {chunk}\n\n"
|
||||
else:
|
||||
yield f"data: {chunk}\n\n"
|
||||
def _generate_full_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_full_response(response)
|
||||
|
||||
return _generate_full_response()
|
||||
else:
|
||||
@ -36,12 +30,8 @@ class AppGenerateResponseConverter(ABC):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
else:
|
||||
|
||||
def _generate_simple_response() -> Generator[str, Any, None]:
|
||||
for chunk in cls.convert_stream_simple_response(response):
|
||||
if chunk == "ping":
|
||||
yield f"event: {chunk}\n\n"
|
||||
else:
|
||||
yield f"data: {chunk}\n\n"
|
||||
def _generate_simple_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_simple_response(response)
|
||||
|
||||
return _generate_simple_response()
|
||||
|
||||
@ -59,14 +49,14 @@ class AppGenerateResponseConverter(ABC):
|
||||
@abstractmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[str, None, None]:
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[str, None, None]:
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@ -1,5 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.file import File, FileUploadConfig
|
||||
@ -138,3 +139,21 @@ class BaseAppGenerator:
|
||||
if isinstance(value, str):
|
||||
return value.replace("\x00", "")
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
|
||||
"""
|
||||
Convert messages into event stream
|
||||
"""
|
||||
if isinstance(generator, dict):
|
||||
return generator
|
||||
else:
|
||||
|
||||
def gen():
|
||||
for message in generator:
|
||||
if isinstance(message, (Mapping, dict)):
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
else:
|
||||
yield f"event: {message}\n\n"
|
||||
|
||||
return gen()
|
||||
|
@ -2,7 +2,7 @@ import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
@ -115,7 +115,7 @@ class AppQueueManager:
|
||||
Set task stop flag
|
||||
:return:
|
||||
"""
|
||||
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
if result is None:
|
||||
return
|
||||
|
||||
|
@ -38,7 +38,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@ -58,7 +58,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -67,7 +67,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -52,9 +52,8 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -80,13 +79,12 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -118,4 +116,4 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
) -> Generator[str | Mapping[str, Any], None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@ -56,8 +56,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
streaming: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
CompletionAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -51,9 +51,8 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls,
|
||||
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -78,13 +77,12 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls,
|
||||
stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -115,4 +113,4 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
@ -36,13 +36,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]: ...
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@ -50,12 +50,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@ -64,26 +64,26 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
@ -124,7 +124,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
@ -146,7 +149,18 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
@ -185,10 +199,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -224,6 +238,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
|
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
@ -36,9 +36,8 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls,
|
||||
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@ -62,13 +61,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls,
|
||||
stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override]
|
||||
) -> Generator[str, None, None]:
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
@ -94,4 +92,4 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
@ -190,7 +191,9 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
and features_dict["text_to_speech"].get("enabled")
|
||||
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
|
||||
):
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
|
||||
tts_publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
|
||||
)
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
@ -527,6 +530,10 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
yield self._text_chunk_to_stream_response(
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -5,6 +5,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@ -27,6 +28,7 @@ from core.app.entities.queue_entities import (
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
@ -239,6 +241,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
@ -373,6 +376,19 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, AgentLogEvent):
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
|
@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
inputs: Mapping
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
@ -41,6 +41,7 @@ class QueueEvent(StrEnum):
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
AGENT_LOG = "agent_log"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
@ -280,6 +281,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
@ -315,6 +317,22 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class QueueAgentLogEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueAgentLogEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.AGENT_LOG
|
||||
id: str
|
||||
label: str
|
||||
node_execution_id: str
|
||||
parent_id: str | None
|
||||
error: str | None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||
"""QueueNodeRetryEvent entity"""
|
||||
|
||||
|
@ -6,6 +6,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -60,6 +61,7 @@ class StreamEvent(Enum):
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
@ -247,6 +249,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
parallel_run_id: Optional[str] = None
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
@ -696,3 +699,26 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class AgentLogStreamResponse(StreamResponse):
|
||||
"""
|
||||
AgentLogStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_execution_id: str
|
||||
id: str
|
||||
label: str
|
||||
parent_id: str | None
|
||||
error: str | None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||
data: Data
|
||||
|
@ -24,6 +24,8 @@ class HostingModerationFeature:
|
||||
if isinstance(prompt_message.content, str):
|
||||
text += prompt_message.content + "\n"
|
||||
|
||||
moderation_result = moderation.check_moderation(model_config, text)
|
||||
moderation_result = moderation.check_moderation(
|
||||
tenant_id=application_generate_entity.app_config.tenant_id, model_config=model_config, text=text
|
||||
)
|
||||
|
||||
return moderation_result
|
||||
|
@ -215,7 +215,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
and text_to_speech_dict.get("autoPlay") == "enabled"
|
||||
and text_to_speech_dict.get("enabled")
|
||||
):
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
|
||||
publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
|
||||
)
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listen_audio_msg(publisher, task_id)
|
||||
|
@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@ -24,6 +25,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
@ -320,9 +322,8 @@ class WorkflowCycleManage:
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
execution_metadata_dict = dict(event.execution_metadata or {})
|
||||
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
@ -540,6 +541,7 @@ class WorkflowCycleManage:
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
parallel_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
),
|
||||
)
|
||||
|
||||
@ -843,3 +845,24 @@ class WorkflowCycleManage:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||
return session.merge(cached_workflow_node_execution)
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
Handle agent log
|
||||
:param task_id: task id
|
||||
:param event: agent log event
|
||||
:return:
|
||||
"""
|
||||
return AgentLogStreamResponse(
|
||||
task_id=task_id,
|
||||
data=AgentLogStreamResponse.Data(
|
||||
node_execution_id=event.node_execution_id,
|
||||
id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
label=event.label,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
),
|
||||
)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any, Optional, TextIO, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -57,7 +57,7 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Mapping[str, Any],
|
||||
tool_outputs: Sequence[ToolInvokeMessage] | str,
|
||||
tool_outputs: Iterable[ToolInvokeMessage] | str,
|
||||
message_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
|
@ -1,5 +1,26 @@
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from collections.abc import Generator, Iterable, Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
def on_tool_execution(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_inputs: Mapping[str, Any],
|
||||
tool_outputs: Iterable[ToolInvokeMessage],
|
||||
message_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
for tool_output in tool_outputs:
|
||||
print_text("\n[on_tool_execution]\n", color=self.color)
|
||||
print_text("Tool: " + tool_name + "\n", color=self.color)
|
||||
print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color)
|
||||
print_text("\n")
|
||||
yield tool_output
|
||||
|
@ -0,0 +1 @@
|
||||
DEFAULT_PLUGIN_ID = "langgenius"
|
42
api/core/entities/parameter_entities.py
Normal file
42
api/core/entities/parameter_entities.py
Normal file
@ -0,0 +1,42 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class CommonParameterType(StrEnum):
|
||||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
SYSTEM_FILES = "system-files"
|
||||
BOOLEAN = "boolean"
|
||||
APP_SELECTOR = "app-selector"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
TOOLS_SELECTOR = "array[tools]"
|
||||
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
|
||||
|
||||
class AppSelectorScope(StrEnum):
|
||||
ALL = "all"
|
||||
CHAT = "chat"
|
||||
WORKFLOW = "workflow"
|
||||
COMPLETION = "completion"
|
||||
|
||||
|
||||
class ModelSelectorScope(StrEnum):
|
||||
LLM = "llm"
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
TTS = "tts"
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
VISION = "vision"
|
||||
|
||||
|
||||
class ToolSelectorScope(StrEnum):
|
||||
ALL = "all"
|
||||
CUSTOM = "custom"
|
||||
BUILTIN = "builtin"
|
||||
WORKFLOW = "workflow"
|
@ -2,13 +2,14 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
@ -18,16 +19,15 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
@ -99,9 +99,10 @@ class ProviderConfiguration(BaseModel):
|
||||
continue
|
||||
|
||||
restrict_models = quota_configuration.restrict_models
|
||||
if self.system_configuration.credentials is None:
|
||||
return None
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
|
||||
copy_credentials = (
|
||||
self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
|
||||
)
|
||||
if restrict_models:
|
||||
for restrict_model in restrict_models:
|
||||
if (
|
||||
@ -140,6 +141,9 @@ class ProviderConfiguration(BaseModel):
|
||||
if current_quota_configuration is None:
|
||||
return None
|
||||
|
||||
if not current_quota_configuration:
|
||||
return SystemConfigurationStatus.UNSUPPORTED
|
||||
|
||||
return (
|
||||
SystemConfigurationStatus.ACTIVE
|
||||
if current_quota_configuration.is_valid
|
||||
@ -153,7 +157,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
|
||||
|
||||
def get_custom_credentials(self, obfuscated: bool = False):
|
||||
def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
|
||||
"""
|
||||
Get custom credentials.
|
||||
|
||||
@ -175,7 +179,7 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
@ -219,6 +223,7 @@ class ProviderConfiguration(BaseModel):
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
@ -246,13 +251,13 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True,
|
||||
)
|
||||
provider_record = Provider()
|
||||
provider_record.tenant_id = self.tenant_id
|
||||
provider_record.provider_name = self.provider.provider
|
||||
provider_record.provider_type = ProviderType.CUSTOM.value
|
||||
provider_record.encrypted_config = json.dumps(credentials)
|
||||
provider_record.is_valid = True
|
||||
|
||||
db.session.add(provider_record)
|
||||
db.session.commit()
|
||||
|
||||
@ -327,7 +332,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def custom_model_credentials_validate(
|
||||
self, model_type: ModelType, model: str, credentials: dict
|
||||
) -> tuple[Optional[ProviderModel], dict]:
|
||||
) -> tuple[ProviderModel | None, dict]:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
@ -370,6 +375,7 @@ class ProviderConfiguration(BaseModel):
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
@ -400,14 +406,13 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_model_record = ProviderModel(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True,
|
||||
)
|
||||
provider_model_record = ProviderModel()
|
||||
provider_model_record.tenant_id = self.tenant_id
|
||||
provider_model_record.provider_name = self.provider.provider
|
||||
provider_model_record.model_name = model
|
||||
provider_model_record.model_type = model_type.to_origin_model_type()
|
||||
provider_model_record.encrypted_config = json.dumps(credentials)
|
||||
provider_model_record.is_valid = True
|
||||
db.session.add(provider_model_record)
|
||||
db.session.commit()
|
||||
|
||||
@ -474,13 +479,12 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
enabled=True,
|
||||
)
|
||||
model_setting = ProviderModelSetting()
|
||||
model_setting.tenant_id = self.tenant_id
|
||||
model_setting.provider_name = self.provider.provider
|
||||
model_setting.model_type = model_type.to_origin_model_type()
|
||||
model_setting.model_name = model
|
||||
model_setting.enabled = True
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
@ -509,13 +513,12 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
enabled=False,
|
||||
)
|
||||
model_setting = ProviderModelSetting()
|
||||
model_setting.tenant_id = self.tenant_id
|
||||
model_setting.provider_name = self.provider.provider
|
||||
model_setting.model_type = model_type.to_origin_model_type()
|
||||
model_setting.model_name = model
|
||||
model_setting.enabled = False
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
@ -576,13 +579,12 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
model_setting = ProviderModelSetting()
|
||||
model_setting.tenant_id = self.tenant_id
|
||||
model_setting.provider_name = self.provider.provider
|
||||
model_setting.model_type = model_type.to_origin_model_type()
|
||||
model_setting.model_name = model
|
||||
model_setting.load_balancing_enabled = True
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
@ -611,25 +613,17 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
else:
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
model_setting = ProviderModelSetting()
|
||||
model_setting.tenant_id = self.tenant_id
|
||||
model_setting.provider_name = self.provider.provider
|
||||
model_setting.model_type = model_type.to_origin_model_type()
|
||||
model_setting.model_name = model
|
||||
model_setting.load_balancing_enabled = False
|
||||
db.session.add(model_setting)
|
||||
db.session.commit()
|
||||
|
||||
return model_setting
|
||||
|
||||
def get_provider_instance(self) -> ModelProvider:
|
||||
"""
|
||||
Get provider instance.
|
||||
:return:
|
||||
"""
|
||||
return model_provider_factory.get_provider_instance(self.provider.provider)
|
||||
|
||||
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
|
||||
"""
|
||||
Get current model type instance.
|
||||
@ -637,11 +631,19 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
# Get provider instance
|
||||
provider_instance = self.get_provider_instance()
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
|
||||
# Get model instance of LLM
|
||||
return provider_instance.get_model_instance(model_type)
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
|
||||
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
return model_provider_factory.get_model_schema(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
|
||||
"""
|
||||
@ -668,11 +670,10 @@ class ProviderConfiguration(BaseModel):
|
||||
if preferred_model_provider:
|
||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
||||
else:
|
||||
preferred_model_provider = TenantPreferredModelProvider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
preferred_provider_type=provider_type.value,
|
||||
)
|
||||
preferred_model_provider = TenantPreferredModelProvider()
|
||||
preferred_model_provider.tenant_id = self.tenant_id
|
||||
preferred_model_provider.provider_name = self.provider.provider
|
||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
||||
db.session.add(preferred_model_provider)
|
||||
|
||||
db.session.commit()
|
||||
@ -737,13 +738,14 @@ class ProviderConfiguration(BaseModel):
|
||||
:param only_active: only active models
|
||||
:return:
|
||||
"""
|
||||
provider_instance = self.get_provider_instance()
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
|
||||
|
||||
model_types = []
|
||||
model_types: list[ModelType] = []
|
||||
if model_type:
|
||||
model_types.append(model_type)
|
||||
else:
|
||||
model_types = list(provider_instance.get_provider_schema().supported_model_types)
|
||||
model_types = list(provider_schema.supported_model_types)
|
||||
|
||||
# Group model settings by model type and model
|
||||
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
|
||||
@ -752,11 +754,11 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
provider_models = self._get_system_provider_models(
|
||||
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
||||
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
|
||||
)
|
||||
else:
|
||||
provider_models = self._get_custom_provider_models(
|
||||
model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
|
||||
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
|
||||
)
|
||||
|
||||
if only_active:
|
||||
@ -767,23 +769,26 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def _get_system_provider_models(
|
||||
self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider,
|
||||
model_types: Sequence[ModelType],
|
||||
provider_schema: ProviderEntity,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get system provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_instance: provider instance
|
||||
:param provider_schema: provider schema
|
||||
:param model_setting_map: model setting map
|
||||
:return:
|
||||
"""
|
||||
provider_models = []
|
||||
for model_type in model_types:
|
||||
for m in provider_instance.models(model_type):
|
||||
for m in provider_schema.models:
|
||||
if m.model_type != model_type:
|
||||
continue
|
||||
|
||||
status = ModelStatus.ACTIVE
|
||||
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
||||
if m.model in model_setting_map:
|
||||
model_setting = model_setting_map[m.model_type][m.model]
|
||||
if model_setting.enabled is False:
|
||||
status = ModelStatus.DISABLED
|
||||
@ -804,7 +809,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
|
||||
for configurate_method in provider_schema.configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
should_use_custom_model = False
|
||||
@ -825,18 +830,22 @@ class ProviderConfiguration(BaseModel):
|
||||
]:
|
||||
# only customizable model
|
||||
for restrict_model in restrict_models:
|
||||
if self.system_configuration.credentials is not None:
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
copy_credentials = (
|
||||
self.system_configuration.credentials.copy()
|
||||
if self.system_configuration.credentials
|
||||
else {}
|
||||
)
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
|
||||
try:
|
||||
custom_model_schema = provider_instance.get_model_instance(
|
||||
restrict_model.model_type
|
||||
).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
|
||||
except Exception as ex:
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
continue
|
||||
try:
|
||||
custom_model_schema = self.get_model_schema(
|
||||
model_type=restrict_model.model_type,
|
||||
model=restrict_model.model,
|
||||
credentials=copy_credentials,
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
|
||||
if not custom_model_schema:
|
||||
continue
|
||||
@ -881,15 +890,15 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def _get_custom_provider_models(
|
||||
self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider,
|
||||
model_types: Sequence[ModelType],
|
||||
provider_schema: ProviderEntity,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get custom provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_instance: provider instance
|
||||
:param provider_schema: provider schema
|
||||
:param model_setting_map: model setting map
|
||||
:return:
|
||||
"""
|
||||
@ -903,8 +912,10 @@ class ProviderConfiguration(BaseModel):
|
||||
if model_type not in self.provider.supported_model_types:
|
||||
continue
|
||||
|
||||
models = provider_instance.models(model_type)
|
||||
for m in models:
|
||||
for m in provider_schema.models:
|
||||
if m.model_type != model_type:
|
||||
continue
|
||||
|
||||
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
||||
load_balancing_enabled = False
|
||||
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
||||
@ -936,10 +947,10 @@ class ProviderConfiguration(BaseModel):
|
||||
continue
|
||||
|
||||
try:
|
||||
custom_model_schema = provider_instance.get_model_instance(
|
||||
model_configuration.model_type
|
||||
).get_customizable_model_schema_from_credentials(
|
||||
model_configuration.model, model_configuration.credentials
|
||||
custom_model_schema = self.get_model_schema(
|
||||
model_type=model_configuration.model_type,
|
||||
model=model_configuration.model,
|
||||
credentials=model_configuration.credentials,
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.warning(f"get custom model schema failed, {ex}")
|
||||
@ -967,7 +978,7 @@ class ProviderConfiguration(BaseModel):
|
||||
label=custom_model_schema.label,
|
||||
model_type=custom_model_schema.model_type,
|
||||
features=custom_model_schema.features,
|
||||
fetch_from=custom_model_schema.fetch_from,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
@ -1040,6 +1051,9 @@ class ProviderConfigurations(BaseModel):
|
||||
return list(self.values())
|
||||
|
||||
def __getitem__(self, key):
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
@ -1051,8 +1065,11 @@ class ProviderConfigurations(BaseModel):
|
||||
def values(self) -> Iterator[ProviderConfiguration]:
|
||||
return iter(self.configurations.values())
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.configurations.get(key, default)
|
||||
def get(self, key, default=None) -> ProviderConfiguration | None:
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations.get(key, default) # type: ignore
|
||||
|
||||
|
||||
class ProviderModelBundle(BaseModel):
|
||||
@ -1061,7 +1078,6 @@ class ProviderModelBundle(BaseModel):
|
||||
"""
|
||||
|
||||
configuration: ProviderConfiguration
|
||||
provider_instance: ModelProvider
|
||||
model_type_instance: AIModel
|
||||
|
||||
# pydantic configs
|
||||
|
@ -1,10 +1,34 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.parameter_entities import (
|
||||
AppSelectorScope,
|
||||
CommonParameterType,
|
||||
ModelSelectorScope,
|
||||
ToolSelectorScope,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderQuotaType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
PAID = "paid"
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class QuotaUnit(Enum):
|
||||
@ -108,3 +132,55 @@ class ModelSettings(BaseModel):
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class BasicProviderConfig(BaseModel):
|
||||
"""
|
||||
Base model class for common provider settings like credentials
|
||||
"""
|
||||
|
||||
class Type(Enum):
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
type: Type = Field(..., description="The type of the credentials")
|
||||
name: str = Field(..., description="The name of the credentials")
|
||||
|
||||
|
||||
class ProviderConfig(BasicProviderConfig):
|
||||
"""
|
||||
Model class for common provider settings like credentials
|
||||
"""
|
||||
|
||||
class Option(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[Option]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
url: Optional[str] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
|
||||
def to_basic_provider_config(self) -> BasicProviderConfig:
|
||||
return BasicProviderConfig(type=self.type, name=self.name)
|
||||
|
@ -20,6 +20,41 @@ def get_signed_file_url(upload_file_id: str) -> str:
|
||||
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
|
||||
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
|
||||
url = f"{dify_config.FILES_URL}/files/upload/for-plugin"
|
||||
|
||||
if user_id is None:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
key = dify_config.SECRET_KEY.encode()
|
||||
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
|
||||
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"
|
||||
|
||||
|
||||
def verify_plugin_file_signature(
|
||||
*, filename: str, mimetype: str, tenant_id: str, user_id: str | None, timestamp: str, nonce: str, sign: str
|
||||
) -> bool:
|
||||
if user_id is None:
|
||||
user_id = "DEFAULT-USER"
|
||||
|
||||
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
|
||||
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
|
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -124,6 +124,17 @@ class File(BaseModel):
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"mime_type": self.mime_type,
|
||||
"filename": self.filename,
|
||||
"extension": self.extension,
|
||||
"size": self.size,
|
||||
"type": self.type,
|
||||
"url": self.generate_url(),
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_after(self):
|
||||
match self.transfer_method:
|
||||
|
69
api/core/file/upload_file_parser.py
Normal file
69
api/core/file/upload_file_parser.py
Normal file
@ -0,0 +1,69 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.url_signer import UrlSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.exception(f"File not found: {upload_file.key}")
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file: UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||
|
||||
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
result = UrlSigner.verify(
|
||||
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
|
||||
)
|
||||
|
||||
# verify signature
|
||||
if not result:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
17
api/core/helper/download.py
Normal file
17
api/core/helper/download.py
Normal file
@ -0,0 +1,17 @@
|
||||
from core.helper import ssrf_proxy
|
||||
|
||||
|
||||
def download_with_size_limit(url, max_download_size: int, **kwargs):
|
||||
response = ssrf_proxy.get(url, follow_redirects=True, **kwargs)
|
||||
if response.status_code == 404:
|
||||
raise ValueError("file not found")
|
||||
|
||||
total_size = 0
|
||||
chunks = []
|
||||
for chunk in response.iter_bytes():
|
||||
total_size += len(chunk)
|
||||
if total_size > max_download_size:
|
||||
raise ValueError("Max file size reached")
|
||||
chunks.append(chunk)
|
||||
content = b"".join(chunks)
|
||||
return content
|
35
api/core/helper/marketplace.py
Normal file
35
api/core/helper/marketplace.py
Normal file
@ -0,0 +1,35 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
import requests
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str):
|
||||
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
|
||||
unique_identifier=plugin_unique_identifier
|
||||
)
|
||||
|
||||
|
||||
def download_plugin_pkg(plugin_unique_identifier: str):
|
||||
url = str(get_plugin_pkg_url(plugin_unique_identifier))
|
||||
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
|
||||
if len(plugin_ids) == 0:
|
||||
return []
|
||||
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
|
||||
response = requests.post(url, json={"plugin_ids": plugin_ids})
|
||||
response.raise_for_status()
|
||||
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
|
||||
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
|
||||
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
|
||||
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response.raise_for_status()
|
@ -1,28 +1,35 @@
|
||||
import logging
|
||||
import random
|
||||
from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
|
||||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
|
||||
def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEntity, text: str) -> bool:
|
||||
moderation_config = hosting_configuration.moderation_config
|
||||
openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
|
||||
if (
|
||||
moderation_config
|
||||
and moderation_config.enabled is True
|
||||
and "openai" in hosting_configuration.provider_map
|
||||
and hosting_configuration.provider_map["openai"].enabled is True
|
||||
and openai_provider_name in hosting_configuration.provider_map
|
||||
and hosting_configuration.provider_map[openai_provider_name].enabled is True
|
||||
):
|
||||
using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type
|
||||
provider_name = model_config.provider
|
||||
if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers:
|
||||
hosting_openai_config = hosting_configuration.provider_map["openai"]
|
||||
assert hosting_openai_config is not None
|
||||
hosting_openai_config = hosting_configuration.provider_map[openai_provider_name]
|
||||
|
||||
if hosting_openai_config.credentials is None:
|
||||
return False
|
||||
|
||||
# 2000 text per chunk
|
||||
length = 2000
|
||||
@ -34,15 +41,20 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str)
|
||||
text_chunk = random.choice(text_chunks)
|
||||
|
||||
try:
|
||||
model_type_instance = OpenAIModerationModel()
|
||||
# FIXME, for type hint using assert or raise ValueError is better here?
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
|
||||
# Get model instance of LLM
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(
|
||||
provider=openai_provider_name, model_type=ModelType.MODERATION
|
||||
)
|
||||
model_type_instance = cast(ModerationModel, model_type_instance)
|
||||
moderation_result = model_type_instance.invoke(
|
||||
model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk
|
||||
model="omni-moderation-latest", credentials=hosting_openai_config.credentials, text=text_chunk
|
||||
)
|
||||
|
||||
if moderation_result is True:
|
||||
return True
|
||||
except Exception as ex:
|
||||
except Exception:
|
||||
logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
|
||||
raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
|
||||
|
||||
|
@ -36,7 +36,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
)
|
||||
|
||||
retries = 0
|
||||
stream = kwargs.pop("stream", False)
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
|
@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
class ToolProviderCredentialsCacheType(Enum):
|
||||
PROVIDER = "tool_provider"
|
||||
ENDPOINT = "endpoint"
|
||||
|
||||
|
||||
class ToolProviderCredentialsCache:
|
||||
|
52
api/core/helper/url_signer.py
Normal file
52
api/core/helper/url_signer.py
Normal file
@ -0,0 +1,52 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class SignedUrlParams(BaseModel):
|
||||
sign_key: str = Field(..., description="The sign key")
|
||||
timestamp: str = Field(..., description="Timestamp")
|
||||
nonce: str = Field(..., description="Nonce")
|
||||
sign: str = Field(..., description="Signature")
|
||||
|
||||
|
||||
class UrlSigner:
|
||||
@classmethod
|
||||
def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str:
|
||||
signed_url_params = cls.get_signed_url_params(sign_key, prefix)
|
||||
return (
|
||||
f"{url}?timestamp={signed_url_params.timestamp}"
|
||||
f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams:
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
sign = cls._sign(sign_key, timestamp, nonce, prefix)
|
||||
|
||||
return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign)
|
||||
|
||||
@classmethod
|
||||
def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool:
|
||||
recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix)
|
||||
|
||||
return sign == recalculated_sign
|
||||
|
||||
@classmethod
|
||||
def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str:
|
||||
if not dify_config.SECRET_KEY:
|
||||
raise Exception("SECRET_KEY is not set")
|
||||
|
||||
data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode()
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return encoded_sign
|
@ -4,9 +4,9 @@ from flask import Flask
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import QuotaUnit, RestrictModel
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
class HostingQuota(BaseModel):
|
||||
@ -48,12 +48,12 @@ class HostingConfiguration:
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
return
|
||||
|
||||
self.provider_map["azure_openai"] = self.init_azure_openai()
|
||||
self.provider_map["openai"] = self.init_openai()
|
||||
self.provider_map["anthropic"] = self.init_anthropic()
|
||||
self.provider_map["minimax"] = self.init_minimax()
|
||||
self.provider_map["spark"] = self.init_spark()
|
||||
self.provider_map["zhipuai"] = self.init_zhipuai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/azure_openai/azure_openai"] = self.init_azure_openai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/openai/openai"] = self.init_openai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/anthropic/anthropic"] = self.init_anthropic()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
|
||||
|
||||
self.moderation_config = self.init_moderation_config()
|
||||
|
||||
@ -240,7 +240,14 @@ class HostingConfiguration:
|
||||
@staticmethod
|
||||
def init_moderation_config() -> HostedModerationConfig:
|
||||
if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS:
|
||||
return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(","))
|
||||
providers = dify_config.HOSTED_MODERATION_PROVIDERS.split(",")
|
||||
hosted_providers = []
|
||||
for provider in providers:
|
||||
if "/" not in provider:
|
||||
provider = f"{DEFAULT_PLUGIN_ID}/{provider}/{provider}"
|
||||
hosted_providers.append(provider)
|
||||
|
||||
return HostedModerationConfig(enabled=True, providers=hosted_providers)
|
||||
|
||||
return HostedModerationConfig(enabled=False)
|
||||
|
||||
|
@ -30,7 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
)
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
from core.tools.utils.rag_web_reader import get_image_upload_file_ids
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
@ -618,10 +618,8 @@ class IndexingRunner:
|
||||
|
||||
tokens = 0
|
||||
if embedding_model_instance:
|
||||
tokens += sum(
|
||||
embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
|
||||
for document in chunk_documents
|
||||
)
|
||||
page_content_list = [document.page_content for document in chunk_documents]
|
||||
tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
|
||||
|
||||
# load index
|
||||
index_processor.load(dataset, chunk_documents, with_keywords=False)
|
||||
|
@ -48,7 +48,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
|
||||
),
|
||||
)
|
||||
answer = cast(str, response.message.content)
|
||||
@ -101,7 +101,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters={"max_tokens": 256, "temperature": 0},
|
||||
stream=False,
|
||||
),
|
||||
@ -110,7 +110,7 @@ class LLMGenerator:
|
||||
questions = output_parser.parse(cast(str, response.message.content))
|
||||
except InvokeError:
|
||||
questions = []
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("Failed to generate suggested questions after answer")
|
||||
questions = []
|
||||
|
||||
@ -150,7 +150,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
|
||||
@ -200,7 +200,7 @@ class LLMGenerator:
|
||||
prompt_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
except InvokeError as e:
|
||||
@ -236,7 +236,7 @@ class LLMGenerator:
|
||||
parameter_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
|
||||
@ -248,7 +248,7 @@ class LLMGenerator:
|
||||
statement_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
rule_config["opening_statement"] = cast(str, statement_content.message.content)
|
||||
@ -301,7 +301,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Sequence
|
||||
from typing import IO, Any, Optional, Union, cast
|
||||
from typing import IO, Any, Literal, Optional, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
@ -98,6 +98,42 @@ class ModelInstance:
|
||||
|
||||
return None
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[True] = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Generator: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: Literal[False] = False,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Union[LLMResult, Generator]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
@ -192,7 +228,7 @@ class ModelInstance:
|
||||
),
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
||||
"""
|
||||
Get number of tokens for text embedding
|
||||
|
||||
@ -204,7 +240,7 @@ class ModelInstance:
|
||||
|
||||
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
||||
return cast(
|
||||
int,
|
||||
list[int],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
@ -397,7 +433,7 @@ class ModelManager:
|
||||
|
||||
return ModelInstance(provider_model_bundle, model)
|
||||
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Return first provider and the first model in the provider
|
||||
:param tenant_id: tenant id
|
||||
|
@ -18,7 +18,6 @@ class ModelType(Enum):
|
||||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
TTS = "tts"
|
||||
TEXT2IMG = "text2img"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, origin_model_type: str) -> "ModelType":
|
||||
@ -37,8 +36,6 @@ class ModelType(Enum):
|
||||
return cls.SPEECH2TEXT
|
||||
elif origin_model_type in {"tts", cls.TTS.value}:
|
||||
return cls.TTS
|
||||
elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
|
||||
return cls.TEXT2IMG
|
||||
elif origin_model_type == cls.MODERATION.value:
|
||||
return cls.MODERATION
|
||||
else:
|
||||
@ -62,8 +59,6 @@ class ModelType(Enum):
|
||||
return "tts"
|
||||
elif self == self.MODERATION:
|
||||
return "moderation"
|
||||
elif self == self.TEXT2IMG:
|
||||
return "text2img"
|
||||
else:
|
||||
raise ValueError(f"invalid model type {self}")
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user