mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 07:15:54 +08:00
Merge branch 'feat/r2' into deploy/dev
This commit is contained in:
commit
78387c1e3d
@ -1,5 +1,4 @@
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12
|
||||
|
||||
# [Optional] Uncomment this section to install additional OS packages.
|
||||
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
# && apt-get -y install --no-install-recommends <your-package-list-here>
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
1
.github/workflows/style.yml
vendored
1
.github/workflows/style.yml
vendored
@ -139,6 +139,7 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
|
||||
@ -87,8 +87,6 @@ Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-host
|
||||
**1. Workflow**:
|
||||
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
@ -54,8 +54,6 @@
|
||||
|
||||
**1. سير العمل**: قم ببناء واختبار سير عمل الذكاء الاصطناعي القوي على قماش بصري، مستفيدًا من جميع الميزات التالية وأكثر.
|
||||
|
||||
<https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa>
|
||||
|
||||
**2. الدعم الشامل للنماذج**: تكامل سلس مع مئات من LLMs الخاصة / مفتوحة المصدر من عشرات من موفري التحليل والحلول المستضافة ذاتيًا، مما يغطي GPT و Mistral و Llama3 وأي نماذج متوافقة مع واجهة OpenAI API. يمكن العثور على قائمة كاملة بمزودي النموذج المدعومين [هنا](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">ডিফাই ওয়ার্কফ্লো ফাইল আপলোড পরিচিতি: গুগল নোটবুক-এলএম পডকাস্ট পুনর্নির্মাণ</a>
|
||||
@ -84,8 +84,6 @@ docker compose up -d
|
||||
**১. ওয়ার্কফ্লো**:
|
||||
ভিজ্যুয়াল ক্যানভাসে AI ওয়ার্কফ্লো তৈরি এবং পরীক্ষা করুন, নিম্নলিখিত সব ফিচার এবং তার বাইরেও আরও অনেক কিছু ব্যবহার করে।
|
||||
|
||||
<https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa>
|
||||
|
||||
**২. মডেল সাপোর্ট**:
|
||||
GPT, Mistral, Llama3, এবং যেকোনো OpenAI API-সামঞ্জস্যপূর্ণ মডেলসহ, কয়েক ডজন ইনফারেন্স প্রদানকারী এবং সেল্ফ-হোস্টেড সমাধান থেকে শুরু করে প্রোপ্রাইটরি/ওপেন-সোর্স LLM-এর সাথে সহজে ইন্টিগ্রেশন। সমর্থিত মডেল প্রদানকারীদের একটি সম্পূর্ণ তালিকা পাওয়া যাবে [এখানে](https://docs.dify.ai/getting-started/readme/model-providers)।
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<div align="center">
|
||||
<a href="https://cloud.dify.ai">Dify 云服务</a> ·
|
||||
@ -61,11 +61,6 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
|
||||
**1. 工作流**:
|
||||
在画布上构建和测试功能强大的 AI 工作流程,利用以下所有功能以及更多功能。
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. 全面的模型支持**:
|
||||
与数百种专有/开源 LLMs 以及数十种推理提供商和自托管解决方案无缝集成,涵盖 GPT、Mistral、Llama3 以及任何与 OpenAI API 兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Einführung in Dify Workflow File Upload: Google NotebookLM Podcast nachbilden</a>
|
||||
@ -83,11 +83,6 @@ Bitte beachten Sie unsere [FAQ](https://docs.dify.ai/getting-started/install-sel
|
||||
**1. Workflow**:
|
||||
Erstellen und testen Sie leistungsstarke KI-Workflows auf einer visuellen Oberfläche, wobei Sie alle der folgenden Funktionen und darüber hinaus nutzen können.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Umfassende Modellunterstützung**:
|
||||
Nahtlose Integration mit Hunderten von proprietären und Open-Source-LLMs von Dutzenden Inferenzanbietern und selbstgehosteten Lösungen, die GPT, Mistral, Llama3 und alle mit der OpenAI API kompatiblen Modelle abdecken. Eine vollständige Liste der unterstützten Modellanbieter finden Sie [hier](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
@ -59,11 +59,6 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto.
|
||||
**1. Flujo de trabajo**:
|
||||
Construye y prueba potentes flujos de trabajo de IA en un lienzo visual, aprovechando todas las siguientes características y más.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Soporte de modelos completo**:
|
||||
Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
@ -59,11 +59,6 @@ Dify est une plateforme de développement d'applications LLM open source. Son in
|
||||
**1. Flux de travail** :
|
||||
Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Prise en charge complète des modèles** :
|
||||
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
@ -60,11 +60,6 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
**1. ワークフロー**:
|
||||
強力なAIワークフローをビジュアルキャンバス上で構築し、テストできます。すべての機能、および以下の機能を使用できます。
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. 総合的なモデルサポート**:
|
||||
数百ものプロプライエタリ/オープンソースのLLMと、数十もの推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、OpenAI APIと互換性のあるすべてのモデルを統合されています。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
@ -59,11 +59,6 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
**1. Workflow**:
|
||||
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify 클라우드</a> ·
|
||||
@ -54,11 +54,6 @@
|
||||
**1. 워크플로우**:
|
||||
다음 기능들을 비롯한 다양한 기능을 활용하여 시각적 캔버스에서 강력한 AI 워크플로우를 구축하고 테스트하세요.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. 포괄적인 모델 지원:**:
|
||||
|
||||
수십 개의 추론 제공업체와 자체 호스팅 솔루션에서 제공하는 수백 개의 독점 및 오픈 소스 LLM과 원활하게 통합되며, GPT, Mistral, Llama3 및 모든 OpenAI API 호환 모델을 포함합니다. 지원되는 모델 제공업체의 전체 목록은 [여기](https://docs.dify.ai/getting-started/readme/model-providers)에서 확인할 수 있습니다.
|
||||
|
@ -1,5 +1,4 @@
|
||||

|
||||
|
||||

|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introduzindo o Dify Workflow com Upload de Arquivo: Recrie o Podcast Google NotebookLM</a>
|
||||
</p>
|
||||
@ -59,11 +58,6 @@ Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto.
|
||||
**1. Workflow**:
|
||||
Construa e teste workflows poderosos de IA em uma interface visual, aproveitando todos os recursos a seguir e muito mais.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Suporte abrangente a modelos**:
|
||||
Integração perfeita com centenas de LLMs proprietários e de código aberto de diversas provedoras e soluções auto-hospedadas, abrangendo GPT, Mistral, Llama3 e qualquer modelo compatível com a API da OpenAI. A lista completa de provedores suportados pode ser encontrada [aqui](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Predstavljamo nalaganje datotek Dify Workflow: znova ustvarite Google NotebookLM Podcast</a>
|
||||
@ -81,11 +81,6 @@ Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-star
|
||||
**1. Potek dela**:
|
||||
Zgradite in preizkusite zmogljive poteke dela AI na vizualnem platnu, pri čemer izkoristite vse naslednje funkcije in več.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Celovita podpora za modele**:
|
||||
Brezhibna integracija s stotinami lastniških/odprtokodnih LLM-jev ducatov ponudnikov sklepanja in samostojnih rešitev, ki pokrivajo GPT, Mistral, Llama3 in vse modele, združljive z API-jem OpenAI. Celoten seznam podprtih ponudnikov modelov najdete [tukaj](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Bulut</a> ·
|
||||
@ -55,11 +55,6 @@ Dify, açık kaynaklı bir LLM uygulama geliştirme platformudur. Sezgisel aray
|
||||
**1. Workflow**:
|
||||
Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edin, aşağıdaki tüm özellikleri ve daha fazlasını kullanarak.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Kapsamlı model desteği**:
|
||||
Çok sayıda çıkarım sağlayıcısı ve kendi kendine barındırılan çözümlerden yüzlerce özel / açık kaynaklı LLM ile sorunsuz entegrasyon sağlar. GPT, Mistral, Llama3 ve OpenAI API uyumlu tüm modelleri kapsar. Desteklenen model sağlayıcılarının tam listesine [buradan](https://docs.dify.ai/getting-started/readme/model-providers) ulaşabilirsiniz.
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">介紹 Dify 工作流程檔案上傳功能:重現 Google NotebookLM Podcast</a>
|
||||
@ -86,8 +86,6 @@ docker compose up -d
|
||||
**1. 工作流程**:
|
||||
在視覺化畫布上建立和測試強大的 AI 工作流程,利用以下所有功能及更多。
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
**2. 全面的模型支援**:
|
||||
無縫整合來自數十個推理提供商和自託管解決方案的數百個專有/開源 LLM,涵蓋 GPT、Mistral、Llama3 和任何與 OpenAI API 兼容的模型。您可以在[此處](https://docs.dify.ai/getting-started/readme/model-providers)找到支援的模型提供商完整列表。
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
@ -55,11 +55,6 @@ Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Gia
|
||||
**1. Quy trình làm việc**:
|
||||
Xây dựng và kiểm tra các quy trình làm việc AI mạnh mẽ trên một canvas trực quan, tận dụng tất cả các tính năng sau đây và hơn thế nữa.
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Hỗ trợ mô hình toàn diện**:
|
||||
Tích hợp liền mạch với hàng trăm mô hình LLM độc quyền / mã nguồn mở từ hàng chục nhà cung cấp suy luận và giải pháp tự lưu trữ, bao gồm GPT, Mistral, Llama3, và bất kỳ mô hình tương thích API OpenAI nào. Danh sách đầy đủ các nhà cung cấp mô hình được hỗ trợ có thể được tìm thấy [tại đây](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||
|
@ -348,6 +348,7 @@ SENTRY_DSN=
|
||||
|
||||
# DEBUG
|
||||
DEBUG=false
|
||||
ENABLE_REQUEST_LOGGING=False
|
||||
SQLALCHEMY_ECHO=false
|
||||
|
||||
# Notion import configuration, support public and internal
|
||||
|
@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
ext_redis,
|
||||
ext_request_logging,
|
||||
ext_sentry,
|
||||
ext_set_secretkey,
|
||||
ext_storage,
|
||||
@ -83,6 +84,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_blueprints,
|
||||
ext_commands,
|
||||
ext_otel,
|
||||
ext_request_logging,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
@ -17,6 +17,12 @@ class DeploymentConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Request logging configuration
|
||||
ENABLE_REQUEST_LOGGING: bool = Field(
|
||||
description="Enable request and response body logging",
|
||||
default=False,
|
||||
)
|
||||
|
||||
EDITION: str = Field(
|
||||
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
||||
default="SELF_HOSTED",
|
||||
|
@ -229,7 +229,7 @@ class HostedFetchPipelineTemplateConfig(BaseSettings):
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||
default="remote",
|
||||
default="database",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="1.3.1",
|
||||
default="1.4.0",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
@ -17,15 +17,13 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
app_detail_fields,
|
||||
app_detail_fields_with_site,
|
||||
app_pagination_fields,
|
||||
)
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
@ -75,7 +73,17 @@ class AppListApi(Resource):
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields)
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
app_ids = [str(app.id) for app in app_pagination.items]
|
||||
res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids)
|
||||
if len(res) != len(app_ids):
|
||||
raise BadRequest("Invalid app id in webapp auth")
|
||||
|
||||
for app in app_pagination.items:
|
||||
if str(app.id) in res:
|
||||
app.access_mode = res[str(app.id)].access_mode
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -119,6 +127,10 @@ class AppApi(Resource):
|
||||
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
|
||||
return app_model
|
||||
|
||||
@setup_required
|
||||
|
@ -81,8 +81,7 @@ class DraftWorkflowApi(Resource):
|
||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
# TODO: set this to required=True after frontend is updated
|
||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
|
@ -1,3 +1,6 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
@ -12,8 +15,7 @@ from fields.workflow_run_fields import (
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models import Account, App, AppMode, EndUser
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
@ -90,7 +92,12 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
run_id = str(run_id)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app_model,
|
||||
run_id=run_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return {"data": node_executions}
|
||||
|
||||
|
@ -24,7 +24,7 @@ from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@ -119,6 +119,9 @@ class ForgotPasswordResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
if reset_data.get("phase", "") != "reset":
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
if reset_data.get("phase", "") != "reset":
|
||||
raise InvalidTokenError()
|
||||
|
||||
@ -168,6 +171,8 @@ class ForgotPasswordResetApi(Resource):
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
pass
|
||||
except WorkspacesLimitExceededError:
|
||||
pass
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
|
@ -21,6 +21,7 @@ from controllers.console.error import (
|
||||
AccountNotFound,
|
||||
EmailSendIpLimitError,
|
||||
NotAllowedCreateWorkspace,
|
||||
WorkspacesLimitExceeded,
|
||||
)
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
@ -30,7 +31,7 @@ from models.account import Account
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@ -88,6 +89,11 @@ class LoginApi(Resource):
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
system_features = FeatureService.get_system_features()
|
||||
|
||||
if system_features.is_allow_create_workspace and not system_features.license.workspaces.is_available():
|
||||
raise WorkspacesLimitExceeded()
|
||||
else:
|
||||
return {
|
||||
"result": "fail",
|
||||
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||
@ -198,6 +204,9 @@ class EmailCodeLoginApi(Resource):
|
||||
if account:
|
||||
tenant = TenantService.get_join_tenants(account)
|
||||
if not tenant:
|
||||
workspaces = FeatureService.get_system_features().license.workspaces
|
||||
if not workspaces.is_available():
|
||||
raise WorkspacesLimitExceeded()
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
else:
|
||||
@ -215,6 +224,8 @@ class EmailCodeLoginApi(Resource):
|
||||
return NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
@ -38,7 +38,7 @@ class PipelineTemplateListApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
@ -101,7 +101,9 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with Session(db.engine) as session:
|
||||
template = session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
|
||||
|
@ -43,7 +43,7 @@ from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -90,11 +90,10 @@ class DraftRagPipelineApi(Resource):
|
||||
if "application/json" in content_type:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("pipeline_variables", type=dict, required=False, location="json")
|
||||
parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
@ -102,8 +101,8 @@ class DraftRagPipelineApi(Resource):
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
|
||||
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
|
||||
raise ValueError("graph or features is not a dict")
|
||||
if not isinstance(data.get("graph"), dict):
|
||||
raise ValueError("graph is not a dict")
|
||||
|
||||
args = {
|
||||
"graph": data.get("graph"),
|
||||
@ -111,7 +110,7 @@ class DraftRagPipelineApi(Resource):
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"pipeline_variables": data.get("pipeline_variables"),
|
||||
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
@ -130,21 +129,15 @@ class DraftRagPipelineApi(Resource):
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
pipeline_variables_list = args.get("pipeline_variables") or {}
|
||||
pipeline_variables = {
|
||||
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
|
||||
for k, v in pipeline_variables_list.items()
|
||||
}
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=args["graph"],
|
||||
features=args["features"],
|
||||
unique_hash=args.get("hash"),
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
pipeline_variables=pipeline_variables,
|
||||
rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
@ -476,7 +469,7 @@ class RagPipelineConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
def get(self, pipeline_id):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
@ -636,12 +629,15 @@ class RagPipelineSecondStepApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider = request.args.get("datasource_provider", required=True, type=str)
|
||||
node_id = request.args.get("node_id", required=True, type=str)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return rag_pipeline_service.get_second_step_parameters(
|
||||
pipeline=pipeline, datasource_provider=datasource_provider
|
||||
variables = rag_pipeline_service.get_second_step_parameters(
|
||||
pipeline=pipeline, node_id=node_id
|
||||
)
|
||||
return {
|
||||
"variables": variables,
|
||||
}
|
||||
|
||||
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@ -713,14 +709,7 @@ class DatasourceListApi(Resource):
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in BuiltinToolManageService.list_rag_pipeline_datasources(
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
)
|
||||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
|
||||
|
||||
|
||||
api.add_resource(
|
||||
@ -792,5 +781,9 @@ api.add_resource(
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceListApi,
|
||||
"/rag/pipelines/datasources",
|
||||
"/rag/pipelines/datasource-plugins",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineSecondStepApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/processing/paramters",
|
||||
)
|
||||
|
@ -46,6 +46,18 @@ class NotAllowedCreateWorkspace(BaseHTTPException):
|
||||
code = 400
|
||||
|
||||
|
||||
class WorkspaceMembersLimitExceeded(BaseHTTPException):
|
||||
error_code = "limit_exceeded"
|
||||
description = "Unable to add member because the maximum workspace's member limit was exceeded"
|
||||
code = 400
|
||||
|
||||
|
||||
class WorkspacesLimitExceeded(BaseHTTPException):
|
||||
error_code = "limit_exceeded"
|
||||
description = "Unable to create workspace because the maximum workspace limit was exceeded"
|
||||
code = 400
|
||||
|
||||
|
||||
class AccountBannedError(BaseHTTPException):
|
||||
error_code = "account_banned"
|
||||
description = "Account is banned."
|
||||
|
@ -23,3 +23,9 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
|
||||
error_code = "app_suggested_questions_after_answer_disabled"
|
||||
description = "Function Suggested questions after answer disabled."
|
||||
code = 403
|
||||
|
||||
|
||||
class AppAccessDeniedError(BaseHTTPException):
|
||||
error_code = "access_denied"
|
||||
description = "App access denied."
|
||||
code = 403
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
@ -15,6 +16,11 @@ from fields.installed_app_fields import installed_app_list_fields
|
||||
from libs.login import login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InstalledAppsListApi(Resource):
|
||||
@ -48,6 +54,21 @@ class InstalledAppsListApi(Resource):
|
||||
for installed_app in installed_apps
|
||||
if installed_app.app is not None
|
||||
]
|
||||
|
||||
# filter out apps that user doesn't have access to
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
user_id = current_user.id
|
||||
res = []
|
||||
for installed_app in installed_app_list:
|
||||
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
|
||||
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_code=app_code,
|
||||
):
|
||||
res.append(installed_app)
|
||||
installed_app_list = res
|
||||
logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}")
|
||||
|
||||
installed_app_list.sort(
|
||||
key=lambda app: (
|
||||
-app["is_pinned"],
|
||||
|
@ -4,10 +4,14 @@ from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models import InstalledApp
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def installed_app_required(view=None):
|
||||
@ -48,6 +52,36 @@ def installed_app_required(view=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def user_allowed_to_access_app(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(installed_app: InstalledApp, *args, **kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
app_id = installed_app.app_id
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=str(current_user.id),
|
||||
app_code=app_code,
|
||||
)
|
||||
if not res:
|
||||
raise AppAccessDeniedError()
|
||||
|
||||
return view(installed_app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
class InstalledAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
method_decorators = [installed_app_required, account_initialization_required, login_required]
|
||||
|
||||
method_decorators = [
|
||||
user_allowed_to_access_app,
|
||||
installed_app_required,
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
@ -6,6 +6,7 @@ from flask_restful import Resource, abort, marshal_with, reqparse
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.error import WorkspaceMembersLimitExceeded
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
@ -17,6 +18,7 @@ from libs.login import login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class MemberListApi(Resource):
|
||||
@ -54,6 +56,12 @@ class MemberInviteEmailApi(Resource):
|
||||
inviter = current_user
|
||||
invitation_results = []
|
||||
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||
|
||||
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
|
||||
|
||||
if not workspace_members.is_available(len(invitee_emails)):
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
try:
|
||||
token = RegisterService.invite_new_member(
|
||||
|
@ -5,5 +5,6 @@ from libs.external_api import ExternalApi
|
||||
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from . import mail
|
||||
from .plugin import plugin
|
||||
from .workspace import workspace
|
||||
|
27
api/controllers/inner_api/mail.py
Normal file
27
api/controllers/inner_api/mail.py
Normal file
@ -0,0 +1,27 @@
|
||||
from flask_restful import (
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from services.enterprise.mail_service import DifyMail, EnterpriseMailService
|
||||
|
||||
|
||||
class EnterpriseMail(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("to", type=str, action="append", required=True)
|
||||
parser.add_argument("subject", type=str, required=True)
|
||||
parser.add_argument("body", type=str, required=True)
|
||||
parser.add_argument("substitutions", type=dict, required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
EnterpriseMailService.send_mail(DifyMail(**args))
|
||||
return {"message": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(EnterpriseMail, "/enterprise/mail")
|
@ -1,12 +1,15 @@
|
||||
from flask_restful import marshal_with
|
||||
from flask import request
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.web import api
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from libs.passport import PassportService
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class AppParameterApi(WebApiResource):
|
||||
@ -40,5 +43,51 @@ class AppMeta(WebApiResource):
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
class AppAccessMode(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("appId", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_id = args["appId"]
|
||||
res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
|
||||
return {"accessMode": res.access_mode}
|
||||
|
||||
|
||||
class AppWebAuthPermission(Resource):
|
||||
def get(self):
|
||||
user_id = "visitor"
|
||||
try:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None:
|
||||
raise
|
||||
if " " not in auth_header:
|
||||
raise
|
||||
|
||||
auth_scheme, tk = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
raise
|
||||
|
||||
decoded = PassportService().verify(tk)
|
||||
user_id = decoded.get("user_id", "visitor")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("appId", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_id = args["appId"]
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
|
||||
return {"result": res}
|
||||
|
||||
|
||||
api.add_resource(AppParameterApi, "/parameters")
|
||||
api.add_resource(AppMeta, "/meta")
|
||||
# webapp auth apis
|
||||
api.add_resource(AppAccessMode, "/webapp/access-mode")
|
||||
api.add_resource(AppWebAuthPermission, "/webapp/permission")
|
||||
|
@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
|
||||
code = 415
|
||||
|
||||
|
||||
class WebSSOAuthRequiredError(BaseHTTPException):
|
||||
class WebAppAuthRequiredError(BaseHTTPException):
|
||||
error_code = "web_sso_auth_required"
|
||||
description = "Web SSO authentication required."
|
||||
description = "Web app authentication required."
|
||||
code = 401
|
||||
|
||||
|
||||
class WebAppAuthAccessDeniedError(BaseHTTPException):
|
||||
error_code = "web_app_access_denied"
|
||||
description = "You do not have permission to access this web app."
|
||||
code = 401
|
||||
|
||||
|
||||
|
120
api/controllers/web/login.py
Normal file
120
api/controllers/web/login.py
Normal file
@ -0,0 +1,120 @@
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from jwt import InvalidTokenError # type: ignore
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
|
||||
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||
from controllers.console.wraps import setup_required
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
from services.account_service import AccountService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class LoginApi(Resource):
|
||||
"""Resource for web app email/password login."""
|
||||
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
if app_code is None:
|
||||
raise BadRequest("X-App-Code header is missing.")
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
raise EmailOrPasswordMismatchError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
raise AccountNotFound()
|
||||
|
||||
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
|
||||
|
||||
end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
|
||||
|
||||
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
|
||||
return {"result": "success", "token": token}
|
||||
|
||||
|
||||
# class LogoutApi(Resource):
|
||||
# @setup_required
|
||||
# def get(self):
|
||||
# account = cast(Account, flask_login.current_user)
|
||||
# if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
# return {"result": "success"}
|
||||
# flask_login.logout_user()
|
||||
# return {"result": "success"}
|
||||
|
||||
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
if app_code is None:
|
||||
raise BadRequest("X-App-Code header is missing.")
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
account = WebAppAuthService.get_user_through_email(user_email)
|
||||
if not account:
|
||||
raise AccountNotFound()
|
||||
|
||||
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
|
||||
|
||||
end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
|
||||
|
||||
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "token": token}
|
||||
|
||||
|
||||
# api.add_resource(LoginApi, "/login")
|
||||
# api.add_resource(LogoutApi, "/logout")
|
||||
# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
@ -5,7 +5,7 @@ from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.error import WebSSOAuthRequiredError
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.model import App, EndUser, Site
|
||||
@ -24,10 +24,10 @@ class PassportResource(Resource):
|
||||
if app_code is None:
|
||||
raise Unauthorized("X-App-Code header is missing.")
|
||||
|
||||
if system_features.sso_enforced_for_web:
|
||||
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
|
||||
if app_web_sso_enabled:
|
||||
raise WebSSOAuthRequiredError()
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||
if not app_settings or not app_settings.access_mode == "public":
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
# get site from db and check if it is normal
|
||||
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
||||
|
@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebSSOAuthRequiredError
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.model import App, EndUser, Site
|
||||
@ -29,7 +29,7 @@ def validate_jwt_token(view=None):
|
||||
|
||||
def decode_jwt_token():
|
||||
system_features = FeatureService.get_system_features()
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
app_code = str(request.headers.get("X-App-Code"))
|
||||
try:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None:
|
||||
@ -57,35 +57,53 @@ def decode_jwt_token():
|
||||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
_validate_web_sso_token(decoded, system_features, app_code)
|
||||
# for enterprise webapp auth
|
||||
app_web_auth_enabled = False
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
|
||||
)
|
||||
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
|
||||
return app_model, end_user
|
||||
except Unauthorized as e:
|
||||
if system_features.sso_enforced_for_web:
|
||||
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
|
||||
if app_web_sso_enabled:
|
||||
raise WebSSOAuthRequiredError()
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
|
||||
)
|
||||
if app_web_auth_enabled:
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
raise Unauthorized(e.description)
|
||||
|
||||
|
||||
def _validate_web_sso_token(decoded, system_features, app_code):
|
||||
app_web_sso_enabled = False
|
||||
|
||||
# Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
|
||||
if system_features.sso_enforced_for_web:
|
||||
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
|
||||
if app_web_sso_enabled:
|
||||
def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
||||
# Check if authentication is enforced for web app, and if the token source is not webapp,
|
||||
# raise an error and redirect to login
|
||||
if system_webapp_auth_enabled and app_web_auth_enabled:
|
||||
source = decoded.get("token_source")
|
||||
if not source or source != "sso":
|
||||
raise WebSSOAuthRequiredError()
|
||||
if not source or source != "webapp":
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
# Check if SSO is not enforced for web, and if the token source is SSO,
|
||||
# Check if authentication is not enforced for web, and if the token source is webapp,
|
||||
# raise an error and redirect to normal passport login
|
||||
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
|
||||
if not system_webapp_auth_enabled or not app_web_auth_enabled:
|
||||
source = decoded.get("token_source")
|
||||
if source and source == "sso":
|
||||
raise Unauthorized("sso token expired.")
|
||||
if source and source == "webapp":
|
||||
raise Unauthorized("webapp token expired.")
|
||||
|
||||
|
||||
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
||||
if system_webapp_auth_enabled and app_web_auth_enabled:
|
||||
# Check if the user is allowed to access the web app
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
|
||||
raise WebAppAuthAccessDeniedError()
|
||||
|
||||
|
||||
class WebApiResource(Resource):
|
||||
|
@ -29,9 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
@ -165,8 +163,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
@ -231,8 +230,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
@ -295,8 +295,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
|
@ -70,7 +70,7 @@ from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models import Conversation, EndUser, Message, MessageFile
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowRunStatus,
|
||||
@ -105,11 +105,11 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
if isinstance(user, EndUser):
|
||||
self._user_id = user.id
|
||||
user_session_id = user.session_id
|
||||
self._created_by_role = CreatedByRole.END_USER
|
||||
self._created_by_role = CreatorUserRole.END_USER
|
||||
elif isinstance(user, Account):
|
||||
self._user_id = user.id
|
||||
user_session_id = user.id
|
||||
self._created_by_role = CreatedByRole.ACCOUNT
|
||||
self._created_by_role = CreatorUserRole.ACCOUNT
|
||||
else:
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
@ -739,9 +739,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
url=file["remote_url"],
|
||||
belongs_to="assistant",
|
||||
upload_file_id=file["related_id"],
|
||||
created_by_role=CreatedByRole.ACCOUNT
|
||||
created_by_role=CreatorUserRole.ACCOUNT
|
||||
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatedByRole.END_USER,
|
||||
else CreatorUserRole.END_USER,
|
||||
created_by=message.from_account_id or message.from_end_user_id or "",
|
||||
)
|
||||
for file in self._recorded_files
|
||||
|
@ -25,7 +25,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -223,7 +223,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
belongs_to="user",
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
|
||||
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||
created_by=account_id or end_user_id or "",
|
||||
)
|
||||
db.session.add(message_file)
|
||||
|
@ -27,7 +27,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow
|
||||
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Workflow
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -138,10 +138,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
@ -262,10 +264,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
@ -325,10 +329,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
|
@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[dict] = None
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
created_at: int
|
||||
extras: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[dict] = None
|
||||
process_data: Optional[dict] = None
|
||||
outputs: Optional[dict] = None
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[dict] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[dict] = None
|
||||
process_data: Optional[dict] = None
|
||||
outputs: Optional[dict] = None
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
execution_metadata: Optional[dict] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
@ -4,7 +4,6 @@ from typing import Any, Optional, TextIO, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
@ -114,35 +113,6 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
color=self.color,
|
||||
)
|
||||
|
||||
def on_datasource_end(
|
||||
self,
|
||||
datasource_name: str,
|
||||
datasource_inputs: Mapping[str, Any],
|
||||
datasource_outputs: Iterable[DatasourceInvokeMessage] | str,
|
||||
message_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> None:
|
||||
"""Run on datasource end."""
|
||||
if dify_config.DEBUG:
|
||||
print_text("\n[on_datasource_end]\n", color=self.color)
|
||||
print_text("Datasource: " + datasource_name + "\n", color=self.color)
|
||||
print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color)
|
||||
print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color)
|
||||
print_text("\n")
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASOURCE_TRACE,
|
||||
message_id=message_id,
|
||||
datasource_name=datasource_name,
|
||||
datasource_inputs=datasource_inputs,
|
||||
datasource_outputs=datasource_outputs,
|
||||
timer=timer,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
|
@ -1,12 +1,9 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceInvokeMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
@ -16,7 +13,6 @@ class DatasourcePlugin:
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[DatasourceParameter]]
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
@ -33,49 +29,41 @@ class DatasourcePlugin:
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters = None
|
||||
|
||||
def datasource_provider_type(self) -> DatasourceProviderType:
|
||||
return DatasourceProviderType.RAG_PIPELINE
|
||||
|
||||
def _invoke_first_step(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: dict[str, Any],
|
||||
rag_pipeline_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
) -> Mapping[str, Any]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
||||
yield from manager.invoke_first_step(
|
||||
return manager.invoke_first_step(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
rag_pipeline_id=rag_pipeline_id,
|
||||
)
|
||||
|
||||
def _invoke_second_step(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: dict[str, Any],
|
||||
rag_pipeline_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
) -> Mapping[str, Any]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
||||
yield from manager.invoke(
|
||||
return manager.invoke_second_step(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
rag_pipeline_id=rag_pipeline_id,
|
||||
)
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
@ -86,28 +74,3 @@ class DatasourcePlugin:
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
rag_pipeline_id: Optional[str] = None,
|
||||
) -> list[DatasourceParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
"""
|
||||
if not self.entity.has_runtime_parameters:
|
||||
return self.entity.parameters
|
||||
|
||||
if self.runtime_parameters is not None:
|
||||
return self.runtime_parameters
|
||||
|
||||
manager = PluginDatasourceManager()
|
||||
self.runtime_parameters = manager.get_runtime_parameters(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="",
|
||||
provider=self.entity.identity.provider,
|
||||
datasource=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
rag_pipeline_id=rag_pipeline_id,
|
||||
)
|
||||
|
||||
return self.runtime_parameters
|
||||
|
@ -2,7 +2,7 @@ from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
@ -22,15 +22,6 @@ class DatasourcePluginProviderController:
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.RAG_PIPELINE
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""
|
||||
|
@ -1,224 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Generator, Iterable
|
||||
from mimetypes import guess_type
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceInvokeMessage,
|
||||
DatasourceInvokeMessageBinary,
|
||||
)
|
||||
from core.file import FileType
|
||||
from core.file.models import FileTransferMethod
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
|
||||
class DatasourceEngine:
|
||||
"""
|
||||
Datasource runtime engine take care of the datasource executions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def invoke_first_step(
|
||||
datasource: DatasourcePlugin,
|
||||
datasource_parameters: dict[str, Any],
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
"""
|
||||
Workflow invokes the datasource with the given arguments.
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_datasource_start(
|
||||
datasource_name=datasource.entity.identity.name, datasource_inputs=datasource_parameters
|
||||
)
|
||||
|
||||
if datasource.runtime and datasource.runtime.runtime_parameters:
|
||||
datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters}
|
||||
|
||||
response = datasource._invoke_first_step(
|
||||
user_id=user_id,
|
||||
datasource_parameters=datasource_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
# hit the callback handler
|
||||
response = workflow_tool_callback.on_datasource_end(
|
||||
datasource_name=datasource.entity.identity.name,
|
||||
datasource_inputs=datasource_parameters,
|
||||
datasource_outputs=response,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def invoke_second_step(
|
||||
datasource: DatasourcePlugin,
|
||||
datasource_parameters: dict[str, Any],
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
"""
|
||||
Workflow invokes the datasource with the given arguments.
|
||||
"""
|
||||
try:
|
||||
response = datasource._invoke_second_step(
|
||||
user_id=user_id,
|
||||
datasource_parameters=datasource_parameters,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle datasource response
|
||||
"""
|
||||
result = ""
|
||||
for response in datasource_response:
|
||||
if response.type == DatasourceInvokeMessage.MessageType.TEXT:
|
||||
result += cast(DatasourceInvokeMessage.TextMessage, response.message).text
|
||||
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
|
||||
result += (
|
||||
f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}."
|
||||
+ " please tell user to check it."
|
||||
)
|
||||
elif response.type in {
|
||||
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
result += (
|
||||
"image has been created and sent to user already, "
|
||||
+ "you do not need to create it, just tell the user to check it now."
|
||||
)
|
||||
elif response.type == DatasourceInvokeMessage.MessageType.JSON:
|
||||
result = json.dumps(
|
||||
cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
result += str(response.message)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _extract_datasource_response_binary_and_text(
|
||||
datasource_response: list[DatasourceInvokeMessage],
|
||||
) -> Generator[DatasourceInvokeMessageBinary, None, None]:
|
||||
"""
|
||||
Extract datasource response binary
|
||||
"""
|
||||
for response in datasource_response:
|
||||
if response.type in {
|
||||
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
mimetype = None
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
if response.meta.get("mime_type"):
|
||||
mimetype = response.meta.get("mime_type")
|
||||
else:
|
||||
try:
|
||||
url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f"a{extension}")
|
||||
if guess_type_result:
|
||||
mimetype = guess_type_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not mimetype:
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
yield DatasourceInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "image/jpeg"),
|
||||
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == DatasourceInvokeMessage.MessageType.BLOB:
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
|
||||
yield DatasourceInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream"),
|
||||
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and "mime_type" in response.meta:
|
||||
yield DatasourceInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream",
|
||||
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_message_files(
|
||||
datasource_messages: Iterable[DatasourceInvokeMessageBinary],
|
||||
agent_message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
:return: message file ids
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in datasource_messages:
|
||||
if "image" in message.mimetype:
|
||||
file_type = FileType.IMAGE
|
||||
elif "video" in message.mimetype:
|
||||
file_type = FileType.VIDEO
|
||||
elif "audio" in message.mimetype:
|
||||
file_type = FileType.AUDIO
|
||||
elif "text" in message.mimetype or "pdf" in message.mimetype:
|
||||
file_type = FileType.DOCUMENT
|
||||
else:
|
||||
file_type = FileType.CUSTOM
|
||||
|
||||
# extract tool file id from url
|
||||
tool_file_id = message.url.split("/")[-1].split(".")[0]
|
||||
message_file = MessageFile(
|
||||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
url=message.url,
|
||||
upload_file_id=tool_file_id,
|
||||
created_by_role=(
|
||||
CreatedByRole.ACCOUNT
|
||||
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatedByRole.END_USER
|
||||
),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(message_file)
|
||||
|
||||
result.append(message_file.id)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return result
|
@ -6,9 +6,8 @@ import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.common_entities import I18nObject
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -36,7 +35,7 @@ class DatasourceManager:
|
||||
if provider in datasource_plugin_providers:
|
||||
return datasource_plugin_providers[provider]
|
||||
|
||||
manager = PluginToolManager()
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(tenant_id, provider)
|
||||
if not provider_entity:
|
||||
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
|
||||
@ -55,7 +54,6 @@ class DatasourceManager:
|
||||
@classmethod
|
||||
def get_datasource_runtime(
|
||||
cls,
|
||||
provider_type: DatasourceProviderType,
|
||||
provider_id: str,
|
||||
datasource_name: str,
|
||||
tenant_id: str,
|
||||
@ -70,18 +68,15 @@ class DatasourceManager:
|
||||
|
||||
:return: the datasource plugin
|
||||
"""
|
||||
if provider_type == DatasourceProviderType.RAG_PIPELINE:
|
||||
return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
|
||||
else:
|
||||
raise DatasourceProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
|
||||
@classmethod
|
||||
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
|
||||
"""
|
||||
list all the datasource providers
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
provider_entities = manager.fetch_datasources(tenant_id)
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entities = manager.fetch_datasource_providers(tenant_id)
|
||||
return [
|
||||
DatasourcePluginProviderController(
|
||||
entity=provider.declaration,
|
||||
|
@ -4,7 +4,6 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
@ -14,7 +13,7 @@ class DatasourceApiEntity(BaseModel):
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
parameters: Optional[list[DatasourceParameter]] = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
@ -1 +0,0 @@
|
||||
DATASOURCE_SELECTOR_MODEL_IDENTITY = "__dify__datasource_selector__"
|
@ -1,13 +1,9 @@
|
||||
import base64
|
||||
import enum
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
@ -17,25 +13,7 @@ from core.plugin.entities.parameters import (
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
VIDEOS = "videos"
|
||||
WEATHER = "weather"
|
||||
FINANCE = "finance"
|
||||
DESIGN = "design"
|
||||
TRAVEL = "travel"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
MEDICAL = "medical"
|
||||
PRODUCTIVITY = "productivity"
|
||||
EDUCATION = "education"
|
||||
BUSINESS = "business"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
UTILITIES = "utilities"
|
||||
OTHER = "other"
|
||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
|
||||
|
||||
class DatasourceProviderType(enum.StrEnum):
|
||||
@ -43,7 +21,9 @@ class DatasourceProviderType(enum.StrEnum):
|
||||
Enum class for datasource provider
|
||||
"""
|
||||
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
LOCAL_FILE = "local_file"
|
||||
WEBSITE = "website"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "DatasourceProviderType":
|
||||
@ -59,153 +39,6 @@ class DatasourceProviderType(enum.StrEnum):
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderSchemaType(Enum):
|
||||
"""
|
||||
Enum class for api provider schema type.
|
||||
"""
|
||||
|
||||
OPENAPI = "openapi"
|
||||
SWAGGER = "swagger"
|
||||
OPENAI_PLUGIN = "openai_plugin"
|
||||
OPENAI_ACTIONS = "openai_actions"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderAuthType(Enum):
|
||||
"""
|
||||
Enum class for api provider auth type.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class DatasourceInvokeMessage(BaseModel):
|
||||
class TextMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
class JsonMessage(BaseModel):
|
||||
json_object: dict
|
||||
|
||||
class BlobMessage(BaseModel):
|
||||
blob: bytes
|
||||
|
||||
class FileMessage(BaseModel):
|
||||
pass
|
||||
|
||||
class VariableMessage(BaseModel):
|
||||
variable_name: str = Field(..., description="The name of the variable")
|
||||
variable_value: Any = Field(..., description="The value of the variable")
|
||||
stream: bool = Field(default=False, description="Whether the variable is streamed")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_variable_value(cls, values) -> Any:
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
"""
|
||||
value = values.get("variable_value")
|
||||
if not isinstance(value, dict | list | str | int | float | bool):
|
||||
raise ValueError("Only basic types and lists are allowed.")
|
||||
|
||||
# if stream is true, the value must be a string
|
||||
if values.get("stream"):
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
|
||||
return values
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
@classmethod
|
||||
def transform_variable_name(cls, value: str) -> str:
|
||||
"""
|
||||
The variable name must be a string.
|
||||
"""
|
||||
if value in {"json", "text", "files"}:
|
||||
raise ValueError(f"The variable name '{value}' is reserved.")
|
||||
return value
|
||||
|
||||
class LogMessage(BaseModel):
|
||||
class LogStatus(Enum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str
|
||||
label: str = Field(..., description="The label of the log")
|
||||
parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
|
||||
error: Optional[str] = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
LINK = "link"
|
||||
BLOB = "blob"
|
||||
JSON = "json"
|
||||
IMAGE_LINK = "image_link"
|
||||
BINARY_LINK = "binary_link"
|
||||
VARIABLE = "variable"
|
||||
FILE = "file"
|
||||
LOG = "log"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
|
||||
meta: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("message", mode="before")
|
||||
@classmethod
|
||||
def decode_blob_message(cls, v):
|
||||
if isinstance(v, dict) and "blob" in v:
|
||||
try:
|
||||
v["blob"] = base64.b64decode(v["blob"])
|
||||
except Exception:
|
||||
pass
|
||||
return v
|
||||
|
||||
@field_serializer("message")
|
||||
def serialize_message(self, v):
|
||||
if isinstance(v, self.BlobMessage):
|
||||
return {"blob": base64.b64encode(v.blob).decode("utf-8")}
|
||||
return v
|
||||
|
||||
|
||||
class DatasourceInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
file_var: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class DatasourceParameter(PluginParameter):
|
||||
"""
|
||||
Overrides type
|
||||
@ -223,8 +56,6 @@ class DatasourceParameter(PluginParameter):
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
@ -235,21 +66,13 @@ class DatasourceParameter(PluginParameter):
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
class DatasourceParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
|
||||
type: DatasourceParameterType = Field(..., description="The type of the parameter")
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
form: DatasourceParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
description: I18nObject = Field(..., description="The description of the parameter")
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
llm_description: str,
|
||||
typ: DatasourceParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
@ -277,30 +100,16 @@ class DatasourceParameter(PluginParameter):
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=typ,
|
||||
form=cls.ToolParameterForm.LLM,
|
||||
llm_description=llm_description,
|
||||
required=required,
|
||||
options=option_objs,
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class ToolProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
default=[],
|
||||
description="The tags of the tool",
|
||||
)
|
||||
|
||||
|
||||
class DatasourceIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
@ -327,26 +136,18 @@ class DatasourceEntity(BaseModel):
|
||||
return v or []
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: Optional[str] = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
class DatasourceProviderEntity(ToolProviderEntity):
|
||||
"""
|
||||
Datasource provider entity
|
||||
"""
|
||||
|
||||
provider_type: DatasourceProviderType
|
||||
|
||||
|
||||
class DatasourceProviderEntityWithPlugin(ToolProviderEntity):
|
||||
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
|
||||
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
"""
|
||||
Workflow tool configuration
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
form: DatasourceParameter.DatasourceParameterForm = Field(..., description="The form of the parameter")
|
||||
|
||||
|
||||
class DatasourceInvokeMeta(BaseModel):
|
||||
"""
|
||||
Datasource invoke meta
|
||||
@ -394,24 +195,3 @@ class DatasourceInvokeFrom(Enum):
|
||||
"""
|
||||
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class DatasourceSelector(BaseModel):
|
||||
dify_model_identity: str = DATASOURCE_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
class Parameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
type: DatasourceParameter.DatasourceParameterType = Field(..., description="The type of the parameter")
|
||||
required: bool = Field(..., description="Whether the parameter is required")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
default: Optional[Union[int, float, str]] = None
|
||||
options: Optional[list[PluginParameterOption]] = None
|
||||
|
||||
provider_id: str = Field(..., description="The id of the provider")
|
||||
datasource_name: str = Field(..., description="The name of the datasource")
|
||||
datasource_description: str = Field(..., description="The description of the datasource")
|
||||
datasource_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||
datasource_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return self.model_dump()
|
||||
|
@ -1,111 +0,0 @@
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum
|
||||
|
||||
ICONS = {
|
||||
ToolLabelEnum.SEARCH: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.IMAGE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.VIDEOS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.WEATHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.FINANCE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.DESIGN: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.TRAVEL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.SOCIAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.NEWS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.MEDICAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.PRODUCTIVITY: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.EDUCATION: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.BUSINESS: """<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
|
||||
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.ENTERTAINMENT: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.UTILITIES: """<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
|
||||
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
}
|
||||
|
||||
default_tool_label_dict = {
|
||||
ToolLabelEnum.SEARCH: ToolLabel(
|
||||
name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH]
|
||||
),
|
||||
ToolLabelEnum.IMAGE: ToolLabel(
|
||||
name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE]
|
||||
),
|
||||
ToolLabelEnum.VIDEOS: ToolLabel(
|
||||
name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS]
|
||||
),
|
||||
ToolLabelEnum.WEATHER: ToolLabel(
|
||||
name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER]
|
||||
),
|
||||
ToolLabelEnum.FINANCE: ToolLabel(
|
||||
name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE]
|
||||
),
|
||||
ToolLabelEnum.DESIGN: ToolLabel(
|
||||
name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN]
|
||||
),
|
||||
ToolLabelEnum.TRAVEL: ToolLabel(
|
||||
name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL]
|
||||
),
|
||||
ToolLabelEnum.SOCIAL: ToolLabel(
|
||||
name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL]
|
||||
),
|
||||
ToolLabelEnum.NEWS: ToolLabel(
|
||||
name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS]
|
||||
),
|
||||
ToolLabelEnum.MEDICAL: ToolLabel(
|
||||
name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL]
|
||||
),
|
||||
ToolLabelEnum.PRODUCTIVITY: ToolLabel(
|
||||
name="productivity",
|
||||
label=I18nObject(en_US="Productivity", zh_Hans="生产力"),
|
||||
icon=ICONS[ToolLabelEnum.PRODUCTIVITY],
|
||||
),
|
||||
ToolLabelEnum.EDUCATION: ToolLabel(
|
||||
name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION]
|
||||
),
|
||||
ToolLabelEnum.BUSINESS: ToolLabel(
|
||||
name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS]
|
||||
),
|
||||
ToolLabelEnum.ENTERTAINMENT: ToolLabel(
|
||||
name="entertainment",
|
||||
label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"),
|
||||
icon=ICONS[ToolLabelEnum.ENTERTAINMENT],
|
||||
),
|
||||
ToolLabelEnum.UTILITIES: ToolLabel(
|
||||
name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES]
|
||||
),
|
||||
ToolLabelEnum.OTHER: ToolLabel(
|
||||
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
|
||||
),
|
||||
}
|
||||
|
||||
default_tool_labels = [v for k, v in default_tool_label_dict.items()]
|
||||
default_tool_label_name_list = [label.name for label in default_tool_labels]
|
@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel):
|
||||
description="The status message of the span. Additional field for context of the event. E.g. the error "
|
||||
"message of an error event.",
|
||||
)
|
||||
input: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||
input: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
|
||||
default=None, description="The input of the span. Can be any JSON object."
|
||||
)
|
||||
output: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||
output: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
|
||||
default=None, description="The output of the span. Can be any JSON object."
|
||||
)
|
||||
version: Optional[str] = Field(
|
||||
|
@ -1,11 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from langfuse import Langfuse # type: ignore
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
@ -30,8 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -113,8 +113,29 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=trace_info.tenant_id
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
@ -124,23 +145,22 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = node_execution.tenant_id
|
||||
app_id = node_execution.app_id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == "llm":
|
||||
inputs = (
|
||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
||||
)
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
@ -152,7 +172,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
model_provider = process_data.get("model_provider", None)
|
||||
model_name = process_data.get("model_name", None)
|
||||
if model_provider is not None and model_name is not None:
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel):
|
||||
|
||||
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
name: Optional[str] = Field(..., description="Name of the run")
|
||||
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
|
||||
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
|
||||
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run")
|
||||
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run")
|
||||
run_type: LangSmithRunType = Field(..., description="Type of the run")
|
||||
start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
|
||||
end_time: Optional[datetime | str] = Field(None, description="End time of the run")
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
@ -7,7 +6,7 @@ from typing import Optional, cast
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
@ -29,8 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
||||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -137,8 +138,29 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
@ -148,27 +170,23 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = node_execution.tenant_id
|
||||
app_id = node_execution.app_id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == "llm":
|
||||
inputs = (
|
||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
||||
)
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = (
|
||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
)
|
||||
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
||||
metadata = execution_metadata.copy()
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
metadata = {str(key): value for key, value in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
@ -181,7 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
run_type = LangSmithRunType.llm
|
||||
@ -191,7 +209,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
"ls_model_name": process_data.get("model_name", ""),
|
||||
}
|
||||
)
|
||||
elif node_type == "knowledge-retrieval":
|
||||
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
run_type = LangSmithRunType.retriever
|
||||
else:
|
||||
run_type = LangSmithRunType.tool
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
@ -7,7 +6,7 @@ from typing import Optional, cast
|
||||
|
||||
from opik import Opik, Trace
|
||||
from opik.id_helpers import uuid4_to_uuid7
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
@ -23,8 +22,10 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -150,8 +151,29 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
@ -161,26 +183,22 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = node_execution.tenant_id
|
||||
app_id = node_execution.app_id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == "llm":
|
||||
inputs = (
|
||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
||||
)
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = (
|
||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
)
|
||||
metadata = execution_metadata.copy()
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
metadata.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
@ -193,7 +211,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
|
||||
provider = None
|
||||
model = None
|
||||
@ -226,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
|
||||
|
||||
if not total_tokens:
|
||||
total_tokens = execution_metadata.get("total_tokens", 0)
|
||||
total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
|
||||
span_data = {
|
||||
"trace_id": opik_trace_id,
|
||||
|
@ -287,7 +287,9 @@ class OpsTraceManager:
|
||||
:return:
|
||||
"""
|
||||
# auth check
|
||||
if tracing_provider not in provider_config_map and tracing_provider is not None:
|
||||
try:
|
||||
provider_config_map[tracing_provider]
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel):
|
||||
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
id: str = Field(..., description="ID of the trace")
|
||||
op: str = Field(..., description="Name of the operation")
|
||||
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
|
||||
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
|
||||
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
|
||||
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
|
||||
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
|
||||
None, description="Metadata and attributes associated with trace"
|
||||
)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
@ -7,6 +6,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -128,58 +130,57 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
|
||||
self.start_call(workflow_run, parent_run_id=trace_info.message_id)
|
||||
|
||||
# through workflow_run_id get all_nodes_execution
|
||||
workflow_nodes_execution_id_records = (
|
||||
db.session.query(WorkflowNodeExecution.id)
|
||||
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
|
||||
.all()
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
for node_execution_id_record in workflow_nodes_execution_id_records:
|
||||
node_execution = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecution.id,
|
||||
WorkflowNodeExecution.tenant_id,
|
||||
WorkflowNodeExecution.app_id,
|
||||
WorkflowNodeExecution.title,
|
||||
WorkflowNodeExecution.node_type,
|
||||
WorkflowNodeExecution.status,
|
||||
WorkflowNodeExecution.inputs,
|
||||
WorkflowNodeExecution.outputs,
|
||||
WorkflowNodeExecution.created_at,
|
||||
WorkflowNodeExecution.elapsed_time,
|
||||
WorkflowNodeExecution.process_data,
|
||||
WorkflowNodeExecution.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
|
||||
.first()
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
if not node_execution:
|
||||
continue
|
||||
|
||||
for node_execution in workflow_node_executions:
|
||||
node_execution_id = node_execution.id
|
||||
tenant_id = node_execution.tenant_id
|
||||
app_id = node_execution.app_id
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
node_name = node_execution.title
|
||||
node_type = node_execution.node_type
|
||||
status = node_execution.status
|
||||
if node_type == "llm":
|
||||
inputs = (
|
||||
json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
|
||||
)
|
||||
if node_type == NodeType.LLM:
|
||||
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
|
||||
else:
|
||||
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
|
||||
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
|
||||
inputs = node_execution.inputs if node_execution.inputs else {}
|
||||
outputs = node_execution.outputs if node_execution.outputs else {}
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
execution_metadata = (
|
||||
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
|
||||
)
|
||||
node_total_tokens = execution_metadata.get("total_tokens", 0)
|
||||
attributes = execution_metadata.copy()
|
||||
execution_metadata = node_execution.metadata if node_execution.metadata else {}
|
||||
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
|
||||
attributes = {str(k): v for k, v in execution_metadata.items()}
|
||||
attributes.update(
|
||||
{
|
||||
"workflow_run_id": trace_info.workflow_run_id,
|
||||
@ -192,7 +193,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data if node_execution.process_data else {}
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
attributes.update(
|
||||
{
|
||||
|
@ -64,9 +64,9 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
return {
|
||||
"inputs": execution.inputs_dict,
|
||||
"outputs": execution.outputs_dict,
|
||||
"process_data": execution.process_data_dict,
|
||||
"inputs": execution.inputs,
|
||||
"outputs": execution.outputs,
|
||||
"process_data": execution.process_data,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -113,7 +113,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
return {
|
||||
"inputs": execution.inputs_dict,
|
||||
"outputs": execution.outputs_dict,
|
||||
"process_data": execution.process_data_dict,
|
||||
"inputs": execution.inputs,
|
||||
"outputs": execution.outputs,
|
||||
"process_data": execution.process_data,
|
||||
}
|
||||
|
@ -1,16 +1,16 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginDatasourceProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
|
||||
|
||||
class PluginDatasourceManager(BasePluginClient):
|
||||
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
|
||||
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
|
||||
"""
|
||||
Fetch datasource providers for the given tenant.
|
||||
"""
|
||||
@ -27,7 +27,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasources",
|
||||
list[PluginToolProviderEntity],
|
||||
list[PluginDatasourceProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
@ -36,12 +36,12 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.tools:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
for datasource in provider.declaration.datasources:
|
||||
datasource.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
|
||||
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
|
||||
"""
|
||||
Fetch datasource provider for the given tenant and plugin.
|
||||
"""
|
||||
@ -58,7 +58,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasources",
|
||||
PluginToolProviderEntity,
|
||||
PluginDatasourceProviderEntity,
|
||||
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
@ -66,8 +66,8 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in response.declaration.tools:
|
||||
tool.identity.provider = response.declaration.identity.name
|
||||
for datasource in response.declaration.datasources:
|
||||
datasource.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
@ -79,7 +79,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: dict[str, Any],
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
@ -88,8 +88,8 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages",
|
||||
ToolInvokeMessage,
|
||||
f"plugin/{tenant_id}/dispatch/datasource/first_step",
|
||||
dict,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
@ -104,7 +104,10 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
return response
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise Exception("No response from plugin daemon")
|
||||
|
||||
def invoke_second_step(
|
||||
self,
|
||||
@ -114,7 +117,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: dict[str, Any],
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
@ -123,8 +126,8 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/invoke_second_step",
|
||||
ToolInvokeMessage,
|
||||
f"plugin/{tenant_id}/dispatch/datasource/second_step",
|
||||
dict,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
@ -139,7 +142,10 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
return response
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise Exception("No response from plugin daemon")
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
@ -151,7 +157,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/tool/validate_credentials",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
@ -170,48 +176,3 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters of the datasource
|
||||
"""
|
||||
datasource_provider_id = GenericProviderID(provider)
|
||||
|
||||
class RuntimeParametersResponse(BaseModel):
|
||||
parameters: list[ToolParameter]
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_runtime_parameters",
|
||||
RuntimeParametersResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.parameters
|
||||
|
||||
return []
|
||||
|
@ -3,10 +3,9 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginDatasourceProviderEntity,
|
||||
PluginToolProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
@ -45,67 +44,6 @@ class PluginToolManager(BasePluginClient):
|
||||
|
||||
return response
|
||||
|
||||
def fetch_datasources(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
|
||||
"""
|
||||
Fetch datasources for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for tool in declaration.get("tools", []):
|
||||
tool["identity"]["provider"] = provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasources",
|
||||
list[PluginToolProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.tools:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
|
||||
"""
|
||||
Fetch datasource provider for the given tenant and plugin.
|
||||
"""
|
||||
datasource_provider_id = DatasourceProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for tool in data.get("declaration", {}).get("tools", []):
|
||||
tool["identity"]["provider"] = datasource_provider_id.provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasource",
|
||||
PluginDatasourceProviderEntity,
|
||||
params={"provider": datasource_provider_id.provider_name, "plugin_id": datasource_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in response.declaration.tools:
|
||||
tool.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
|
@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceSelector
|
||||
from core.file.models import File
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
|
||||
@ -19,10 +18,4 @@ def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str,
|
||||
parameters[parameter_name] = []
|
||||
for p in parameter:
|
||||
parameters[parameter_name].append(p.to_plugin_parameter())
|
||||
elif isinstance(parameter, DatasourceSelector):
|
||||
parameters[parameter_name] = parameter.to_plugin_parameter()
|
||||
elif isinstance(parameter, list) and all(isinstance(p, DatasourceSelector) for p in parameter):
|
||||
parameters[parameter_name] = []
|
||||
for p in parameter:
|
||||
parameters[parameter_name].append(p.to_plugin_parameter())
|
||||
return parameters
|
||||
|
@ -6,6 +6,12 @@ from urllib.parse import urljoin
|
||||
import requests
|
||||
from requests import Response
|
||||
|
||||
from core.rag.extractor.watercrawl.exceptions import (
|
||||
WaterCrawlAuthenticationError,
|
||||
WaterCrawlBadRequestError,
|
||||
WaterCrawlPermissionError,
|
||||
)
|
||||
|
||||
|
||||
class BaseAPIClient:
|
||||
def __init__(self, api_key, base_url):
|
||||
@ -53,6 +59,15 @@ class WaterCrawlAPIClient(BaseAPIClient):
|
||||
yield data
|
||||
|
||||
def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
|
||||
if response.status_code == 401:
|
||||
raise WaterCrawlAuthenticationError(response)
|
||||
|
||||
if response.status_code == 403:
|
||||
raise WaterCrawlPermissionError(response)
|
||||
|
||||
if 400 <= response.status_code < 500:
|
||||
raise WaterCrawlBadRequestError(response)
|
||||
|
||||
response.raise_for_status()
|
||||
if response.status_code == 204:
|
||||
return None
|
||||
|
32
api/core/rag/extractor/watercrawl/exceptions.py
Normal file
32
api/core/rag/extractor/watercrawl/exceptions.py
Normal file
@ -0,0 +1,32 @@
|
||||
import json
|
||||
|
||||
|
||||
class WaterCrawlError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WaterCrawlBadRequestError(WaterCrawlError):
|
||||
def __init__(self, response):
|
||||
self.status_code = response.status_code
|
||||
self.response = response
|
||||
data = response.json()
|
||||
self.message = data.get("message", "Unknown error occurred")
|
||||
self.errors = data.get("errors", {})
|
||||
super().__init__(self.message)
|
||||
|
||||
@property
|
||||
def flat_errors(self):
|
||||
return json.dumps(self.errors)
|
||||
|
||||
def __str__(self):
|
||||
return f"WaterCrawlBadRequestError: {self.message} \n {self.flat_errors}"
|
||||
|
||||
|
||||
class WaterCrawlPermissionError(WaterCrawlBadRequestError):
|
||||
def __str__(self):
|
||||
return f"You are exceeding your WaterCrawl API limits. {self.message}"
|
||||
|
||||
|
||||
class WaterCrawlAuthenticationError(WaterCrawlBadRequestError):
|
||||
def __str__(self):
|
||||
return "WaterCrawl API key is invalid or expired. Please check your API key and try again."
|
@ -19,7 +19,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.enums import CreatedByRole
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -116,7 +116,7 @@ class WordExtractor(BaseExtractor):
|
||||
extension=str(image_ext),
|
||||
mime_type=mime_type or "",
|
||||
created_by=self.user_id,
|
||||
created_by_role=CreatedByRole.ACCOUNT,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
used=True,
|
||||
used_by=self.user_id,
|
||||
|
@ -2,16 +2,30 @@
|
||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import (
|
||||
NodeExecution,
|
||||
NodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -23,16 +37,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
||||
Each method creates its own session, handles the transaction, and commits changes
|
||||
to the database. This prevents long-running connections in the workflow core.
|
||||
|
||||
This implementation also includes an in-memory cache for node executions to improve
|
||||
performance by reducing database queries.
|
||||
"""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
user: Union[Account, EndUser],
|
||||
app_id: Optional[str],
|
||||
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
|
||||
):
|
||||
"""
|
||||
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
|
||||
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
||||
tenant_id: Tenant ID for multi-tenancy
|
||||
app_id: Optional app ID for filtering by application
|
||||
user: Account or EndUser object containing tenant_id, user ID, and role information
|
||||
app_id: App ID for filtering by application (can be None)
|
||||
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
|
||||
"""
|
||||
# If an engine is provided, create a sessionmaker from it
|
||||
if isinstance(session_factory, Engine):
|
||||
@ -44,38 +68,168 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# Store app context
|
||||
self._app_id = app_id
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
# Extract user context
|
||||
self._triggered_from = triggered_from
|
||||
self._creator_user_id = user.id
|
||||
|
||||
# Determine user role based on user type
|
||||
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||
|
||||
# Initialize in-memory cache for node executions
|
||||
# Key: node_execution_id, Value: NodeExecution
|
||||
self._node_execution_cache: dict[str, NodeExecution] = {}
|
||||
|
||||
def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance and commit changes to the database.
|
||||
Convert a database model to a domain model.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
db_model: The database model to convert
|
||||
|
||||
Returns:
|
||||
The domain model
|
||||
"""
|
||||
# Parse JSON fields
|
||||
inputs = db_model.inputs_dict
|
||||
process_data = db_model.process_data_dict
|
||||
outputs = db_model.outputs_dict
|
||||
metadata = db_model.execution_metadata_dict
|
||||
|
||||
# Convert status to domain enum
|
||||
status = NodeExecutionStatus(db_model.status)
|
||||
|
||||
return NodeExecution(
|
||||
id=db_model.id,
|
||||
node_execution_id=db_model.node_execution_id,
|
||||
workflow_id=db_model.workflow_id,
|
||||
workflow_run_id=db_model.workflow_run_id,
|
||||
index=db_model.index,
|
||||
predecessor_node_id=db_model.predecessor_node_id,
|
||||
node_id=db_model.node_id,
|
||||
node_type=NodeType(db_model.node_type),
|
||||
title=db_model.title,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
status=status,
|
||||
error=db_model.error,
|
||||
elapsed_time=db_model.elapsed_time,
|
||||
# FIXME(QuantumGhost): a temporary workaround for the following type check failure in Python 3.11.
|
||||
# However, this problem is not occurred in Python 3.12.
|
||||
#
|
||||
# A case of this error is:
|
||||
# https://github.com/langgenius/dify/actions/runs/15112698604/job/42475659482?pr=19737#step:9:24
|
||||
metadata=cast(Mapping[NodeRunMetadataKey, Any] | None, metadata),
|
||||
created_at=db_model.created_at,
|
||||
finished_at=db_model.finished_at,
|
||||
)
|
||||
|
||||
def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Convert a domain model to a database model.
|
||||
|
||||
Args:
|
||||
domain_model: The domain model to convert
|
||||
|
||||
Returns:
|
||||
The database model
|
||||
"""
|
||||
# Use values from constructor if provided
|
||||
if not self._triggered_from:
|
||||
raise ValueError("triggered_from is required in repository constructor")
|
||||
if not self._creator_user_id:
|
||||
raise ValueError("created_by is required in repository constructor")
|
||||
if not self._creator_user_role:
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
db_model = WorkflowNodeExecution()
|
||||
db_model.id = domain_model.id
|
||||
db_model.tenant_id = self._tenant_id
|
||||
if self._app_id is not None:
|
||||
db_model.app_id = self._app_id
|
||||
db_model.workflow_id = domain_model.workflow_id
|
||||
db_model.triggered_from = self._triggered_from
|
||||
db_model.workflow_run_id = domain_model.workflow_run_id
|
||||
db_model.index = domain_model.index
|
||||
db_model.predecessor_node_id = domain_model.predecessor_node_id
|
||||
db_model.node_execution_id = domain_model.node_execution_id
|
||||
db_model.node_id = domain_model.node_id
|
||||
db_model.node_type = domain_model.node_type
|
||||
db_model.title = domain_model.title
|
||||
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
|
||||
db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None
|
||||
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
|
||||
db_model.status = domain_model.status
|
||||
db_model.error = domain_model.error
|
||||
db_model.elapsed_time = domain_model.elapsed_time
|
||||
db_model.execution_metadata = json.dumps(domain_model.metadata) if domain_model.metadata else None
|
||||
db_model.created_at = domain_model.created_at
|
||||
db_model.created_by_role = self._creator_user_role
|
||||
db_model.created_by = self._creator_user_id
|
||||
db_model.finished_at = domain_model.finished_at
|
||||
return db_model
|
||||
|
||||
def save(self, execution: NodeExecution) -> None:
|
||||
"""
|
||||
Save or update a NodeExecution domain entity to the database.
|
||||
|
||||
This method serves as a domain-to-database adapter that:
|
||||
1. Converts the domain entity to its database representation
|
||||
2. Persists the database model using SQLAlchemy's merge operation
|
||||
3. Maintains proper multi-tenancy by including tenant context during conversion
|
||||
4. Updates the in-memory cache for faster subsequent lookups
|
||||
|
||||
The method handles both creating new records and updating existing ones through
|
||||
SQLAlchemy's merge operation.
|
||||
|
||||
Args:
|
||||
execution: The NodeExecution domain entity to persist
|
||||
"""
|
||||
# Convert domain model to database model using tenant context and other attributes
|
||||
db_model = self.to_db_model(execution)
|
||||
|
||||
# Create a new database session
|
||||
with self._session_factory() as session:
|
||||
# Ensure tenant_id is set
|
||||
if not execution.tenant_id:
|
||||
execution.tenant_id = self._tenant_id
|
||||
|
||||
# Set app_id if provided and not already set
|
||||
if self._app_id and not execution.app_id:
|
||||
execution.app_id = self._app_id
|
||||
|
||||
session.add(execution)
|
||||
# SQLAlchemy merge intelligently handles both insert and update operations
|
||||
# based on the presence of the primary key
|
||||
session.merge(db_model)
|
||||
session.commit()
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
# Update the in-memory cache for faster subsequent lookups
|
||||
# Only cache if we have a node_execution_id to use as the cache key
|
||||
if db_model.node_execution_id:
|
||||
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
|
||||
self._node_execution_cache[db_model.node_execution_id] = execution
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
Retrieve a NodeExecution by its node_execution_id.
|
||||
|
||||
First checks the in-memory cache, and if not found, queries the database.
|
||||
If found in the database, adds it to the cache for future lookups.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
The NodeExecution instance if found, None otherwise
|
||||
"""
|
||||
# First check the cache
|
||||
if node_execution_id in self._node_execution_cache:
|
||||
logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
|
||||
return self._node_execution_cache[node_execution_id]
|
||||
|
||||
# If not in cache, query the database
|
||||
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
@ -85,15 +239,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
return session.scalar(stmt)
|
||||
db_model = session.scalar(stmt)
|
||||
if db_model:
|
||||
# Convert to domain model
|
||||
domain_model = self._to_domain_model(db_model)
|
||||
|
||||
# Add to cache
|
||||
self._node_execution_cache[node_execution_id] = domain_model
|
||||
|
||||
return domain_model
|
||||
|
||||
return None
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
) -> Sequence[NodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
|
||||
This method always queries the database to ensure complete and ordered results,
|
||||
but updates the cache with any retrieved executions.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
@ -102,7 +269,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
A list of NodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
@ -129,17 +296,31 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
if order_columns:
|
||||
stmt = stmt.order_by(*order_columns)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
db_models = session.scalars(stmt).all()
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
# Convert database models to domain models and update cache
|
||||
domain_models = []
|
||||
for model in db_models:
|
||||
domain_model = self._to_domain_model(model)
|
||||
# Update cache if node_execution_id is present
|
||||
if domain_model.node_execution_id:
|
||||
self._node_execution_cache[domain_model.node_execution_id] = domain_model
|
||||
domain_models.append(domain_model)
|
||||
|
||||
return domain_models
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||
|
||||
This method queries the database directly and updates the cache with any
|
||||
retrieved executions that have a node_execution_id.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
A list of running NodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
@ -152,26 +333,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
db_models = session.scalars(stmt).all()
|
||||
domain_models = []
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance and commit changes to the database.
|
||||
for model in db_models:
|
||||
domain_model = self._to_domain_model(model)
|
||||
# Update cache if node_execution_id is present
|
||||
if domain_model.node_execution_id:
|
||||
self._node_execution_cache[domain_model.node_execution_id] = domain_model
|
||||
domain_models.append(domain_model)
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
# Ensure tenant_id is set
|
||||
if not execution.tenant_id:
|
||||
execution.tenant_id = self._tenant_id
|
||||
|
||||
# Set app_id if provided and not already set
|
||||
if self._app_id and not execution.app_id:
|
||||
execution.app_id = self._app_id
|
||||
|
||||
session.merge(execution)
|
||||
session.commit()
|
||||
return domain_models
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
@ -179,6 +351,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
This method deletes all WorkflowNodeExecution records that match the tenant_id
|
||||
and app_id (if provided) associated with this repository instance.
|
||||
It also clears the in-memory cache.
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||
@ -194,3 +367,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
|
||||
+ (f" and app {self._app_id}" if self._app_id else "")
|
||||
)
|
||||
|
||||
# Clear the in-memory cache
|
||||
self._node_execution_cache.clear()
|
||||
logger.info("Cleared in-memory node execution cache")
|
||||
|
@ -32,7 +32,7 @@ from core.tools.errors import (
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatedByRole
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
|
||||
@ -339,9 +339,9 @@ class ToolEngine:
|
||||
url=message.url,
|
||||
upload_file_id=tool_file_id,
|
||||
created_by_role=(
|
||||
CreatedByRole.ACCOUNT
|
||||
CreatorUserRole.ACCOUNT
|
||||
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatedByRole.END_USER
|
||||
else CreatorUserRole.END_USER
|
||||
),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@ -496,31 +495,6 @@ class ToolManager:
|
||||
# get plugin providers
|
||||
yield from cls.list_plugin_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
|
||||
"""
|
||||
list all the datasource providers
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
provider_entities = manager.fetch_datasources(tenant_id)
|
||||
return [
|
||||
DatasourcePluginProviderController(
|
||||
entity=provider.declaration,
|
||||
plugin_id=provider.plugin_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for provider in provider_entities
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def list_builtin_datasources(cls, tenant_id: str) -> Generator[DatasourcePluginProviderController, None, None]:
|
||||
"""
|
||||
list all the builtin datasources
|
||||
"""
|
||||
# get builtin datasources
|
||||
yield from cls.list_datasource_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
|
||||
"""
|
||||
|
7
api/core/variables/consts.py
Normal file
7
api/core/variables/consts.py
Normal file
@ -0,0 +1,7 @@
|
||||
# The minimal selector length for valid variables.
|
||||
#
|
||||
# The first element of the selector is the node id, and the second element is the variable name.
|
||||
#
|
||||
# If the selector length is more than 2, the remaining parts are the keys / indexes paths used
|
||||
# to extract part of the variable value.
|
||||
MIN_SELECTORS_LENGTH = 2
|
8
api/core/variables/utils.py
Normal file
8
api/core/variables/utils.py
Normal file
@ -0,0 +1,8 @@
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
|
||||
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
||||
selectors = [node_id, name]
|
||||
if paths:
|
||||
selectors.extend(paths)
|
||||
return selectors
|
98
api/core/workflow/entities/node_execution_entities.py
Normal file
98
api/core/workflow/entities/node_execution_entities.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""
|
||||
Domain entities for workflow node execution.
|
||||
|
||||
This module contains the domain model for workflow node execution, which is used
|
||||
by the core workflow module. These models are independent of the storage mechanism
|
||||
and don't contain implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class NodeExecutionStatus(StrEnum):
|
||||
"""
|
||||
Node Execution Status Enum.
|
||||
"""
|
||||
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
EXCEPTION = "exception"
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
"""
|
||||
Domain model for workflow node execution.
|
||||
|
||||
This model represents the core business entity of a node execution,
|
||||
without implementation details like tenant_id, app_id, etc.
|
||||
|
||||
Note: User/context-specific fields (triggered_from, created_by, created_by_role)
|
||||
have been moved to the repository implementation to keep the domain model clean.
|
||||
These fields are still accepted in the constructor for backward compatibility,
|
||||
but they are not stored in the model.
|
||||
"""
|
||||
|
||||
# Core identification fields
|
||||
id: str # Unique identifier for this execution record
|
||||
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
|
||||
workflow_id: str # ID of the workflow this node belongs to
|
||||
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||
|
||||
# Execution positioning and flow
|
||||
index: int # Sequence number for ordering in trace visualization
|
||||
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
|
||||
node_id: str # ID of the node being executed
|
||||
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
|
||||
title: str # Display title of the node
|
||||
|
||||
# Execution data
|
||||
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
|
||||
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
|
||||
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
|
||||
|
||||
# Execution state
|
||||
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
|
||||
error: Optional[str] = None # Error message if execution failed
|
||||
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
|
||||
|
||||
# Additional metadata
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||
|
||||
# Timing information
|
||||
created_at: datetime # When execution started
|
||||
finished_at: Optional[datetime] = None # When execution completed
|
||||
|
||||
def update_from_mapping(
|
||||
self,
|
||||
inputs: Optional[Mapping[str, Any]] = None,
|
||||
process_data: Optional[Mapping[str, Any]] = None,
|
||||
outputs: Optional[Mapping[str, Any]] = None,
|
||||
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update the model from mappings.
|
||||
|
||||
Args:
|
||||
inputs: The inputs to update
|
||||
process_data: The process data to update
|
||||
outputs: The outputs to update
|
||||
metadata: The metadata to update
|
||||
"""
|
||||
if inputs is not None:
|
||||
self.inputs = dict(inputs)
|
||||
if process_data is not None:
|
||||
self.process_data = dict(process_data)
|
||||
if outputs is not None:
|
||||
self.outputs = dict(outputs)
|
||||
if metadata is not None:
|
||||
self.metadata = dict(metadata)
|
@ -1,35 +1,24 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.datasource.datasource_engine import DatasourceEngine
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter
|
||||
from core.datasource.errors import DatasourceInvokeError
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceParameter,
|
||||
)
|
||||
from core.file import File
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError
|
||||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
|
||||
|
||||
class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
@ -49,7 +38,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
|
||||
# fetch datasource icon
|
||||
datasource_info = {
|
||||
"provider_type": node_data.provider_type.value,
|
||||
"provider_id": node_data.provider_id,
|
||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
||||
}
|
||||
@ -58,8 +46,10 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
try:
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=node_data.provider_id,
|
||||
datasource_name=node_data.datasource_name,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
@ -74,7 +64,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
return
|
||||
|
||||
# get parameters
|
||||
datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
|
||||
datasource_parameters = datasource_runtime.entity.parameters
|
||||
parameters = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
@ -91,15 +81,20 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = DatasourceEngine.generic_invoke(
|
||||
datasource=datasource_runtime,
|
||||
datasource_parameters=parameters,
|
||||
# TODO: handle result
|
||||
result = datasource_runtime._invoke_second_step(
|
||||
user_id=self.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
datasource_parameters=parameters,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to transform datasource message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except DatasourceNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
@ -113,20 +108,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# convert datasource messages
|
||||
yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
|
||||
except (PluginDaemonClientSideError, DatasourceInvokeError) as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to transform datasource message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
@ -175,200 +156,6 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[DatasourceInvokeMessage, None, None],
|
||||
datasource_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
DatasourceInvokeMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceInvokeMessage.MessageType.BINARY_LINK,
|
||||
DatasourceInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
|
||||
if self.node_type == NodeType.AGENT:
|
||||
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||
agent_execution_metadata = {
|
||||
key: value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in NodeRunMetadataKey.__members__.values()
|
||||
}
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, DatasourceInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = datasource_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstallationManager()
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
self.user_id,
|
||||
self.tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
id=message.message.id,
|
||||
node_execution_id=self.id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=self.node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.id == agent_log.id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": files, "json": json, **variables},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
|
||||
NodeRunMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
@ -3,17 +3,15 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: ToolProviderType
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
datasource_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
datasource_configurations: dict[str, Any]
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
|
@ -127,7 +127,7 @@ class GeneralStructureChunk(BaseModel):
|
||||
General Structure Chunk.
|
||||
"""
|
||||
|
||||
general_chunk: list[str]
|
||||
general_chunks: list[str]
|
||||
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast, Mapping
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
|
@ -2,12 +2,12 @@ from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Protocol
|
||||
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
from core.workflow.entities.node_execution_entities import NodeExecution
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderConfig:
|
||||
"""Configuration for ordering WorkflowNodeExecution instances."""
|
||||
"""Configuration for ordering NodeExecution instances."""
|
||||
|
||||
order_by: list[str]
|
||||
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||
@ -15,10 +15,10 @@ class OrderConfig:
|
||||
|
||||
class WorkflowNodeExecutionRepository(Protocol):
|
||||
"""
|
||||
Repository interface for WorkflowNodeExecution.
|
||||
Repository interface for NodeExecution.
|
||||
|
||||
This interface defines the contract for accessing and manipulating
|
||||
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
||||
NodeExecution data, regardless of the underlying storage mechanism.
|
||||
|
||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||
@ -26,24 +26,28 @@ class WorkflowNodeExecutionRepository(Protocol):
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
def save(self, execution: NodeExecution) -> None:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance.
|
||||
Save or update a NodeExecution instance.
|
||||
|
||||
This method handles both creating new records and updating existing ones.
|
||||
The implementation should determine whether to create or update based on
|
||||
the execution's ID or other identifying fields.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
execution: The NodeExecution instance to save or update
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
Retrieve a NodeExecution by its node_execution_id.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
The NodeExecution instance if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@ -51,9 +55,9 @@ class WorkflowNodeExecutionRepository(Protocol):
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
) -> Sequence[NodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
Retrieve all NodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
@ -62,34 +66,25 @@ class WorkflowNodeExecutionRepository(Protocol):
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
A list of NodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
A list of running NodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
|
||||
Clear all NodeExecution records based on implementation-specific criteria.
|
||||
|
||||
This method is intended to be used for bulk deletion operations, such as removing
|
||||
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
|
||||
|
@ -58,7 +58,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
@ -94,11 +94,11 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
if isinstance(user, EndUser):
|
||||
self._user_id = user.id
|
||||
user_session_id = user.session_id
|
||||
self._created_by_role = CreatedByRole.END_USER
|
||||
self._created_by_role = CreatorUserRole.END_USER
|
||||
elif isinstance(user, Account):
|
||||
self._user_id = user.id
|
||||
user_session_id = user.id
|
||||
self._created_by_role = CreatedByRole.ACCOUNT
|
||||
self._created_by_role = CreatorUserRole.ACCOUNT
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}")
|
||||
|
||||
|
@ -46,26 +46,28 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.entities.node_execution_entities import (
|
||||
NodeExecution,
|
||||
NodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
EndUser,
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
|
||||
|
||||
@ -78,7 +80,6 @@ class WorkflowCycleManager:
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
) -> None:
|
||||
self._workflow_run: WorkflowRun | None = None
|
||||
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
@ -89,7 +90,7 @@ class WorkflowCycleManager:
|
||||
session: Session,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
created_by_role: CreatedByRole,
|
||||
created_by_role: CreatorUserRole,
|
||||
) -> WorkflowRun:
|
||||
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
|
||||
workflow = session.scalar(workflow_stmt)
|
||||
@ -258,21 +259,22 @@ class WorkflowCycleManager:
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
# Use the instance repository to find running executions for a workflow run
|
||||
running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
|
||||
running_domain_executions = self._workflow_node_execution_repository.get_running_executions(
|
||||
workflow_run_id=workflow_run.id
|
||||
)
|
||||
|
||||
# Update the cache with the retrieved executions
|
||||
for execution in running_workflow_node_executions:
|
||||
if execution.node_execution_id:
|
||||
self._workflow_node_executions[execution.node_execution_id] = execution
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
# Update the domain models
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.finished_at = now
|
||||
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||
for domain_execution in running_domain_executions:
|
||||
if domain_execution.node_execution_id:
|
||||
# Update the domain model
|
||||
domain_execution.status = NodeExecutionStatus.FAILED
|
||||
domain_execution.error = error
|
||||
domain_execution.finished_at = now
|
||||
domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds()
|
||||
|
||||
# Update the repository with the domain model
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@ -286,63 +288,67 @@ class WorkflowCycleManager:
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_node_execution_start(
|
||||
self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.execution_metadata = json.dumps(
|
||||
{
|
||||
def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution:
|
||||
# Create a domain model
|
||||
created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
metadata = {
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
|
||||
domain_execution = NodeExecution(
|
||||
id=str(uuid4()),
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
index=event.node_run_index,
|
||||
node_execution_id=event.node_execution_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=event.node_data.title,
|
||||
status=NodeExecutionStatus.RUNNING,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
)
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
# Use the instance repository to save the workflow node execution
|
||||
self._workflow_node_execution_repository.save(workflow_node_execution)
|
||||
# Use the instance repository to save the domain model
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
return domain_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
|
||||
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
|
||||
# Get the domain model from repository
|
||||
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
|
||||
if not domain_execution:
|
||||
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
|
||||
|
||||
# Process data
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
execution_metadata_dict = dict(event.execution_metadata or {})
|
||||
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
|
||||
|
||||
# Convert metadata keys to strings
|
||||
execution_metadata_dict = {}
|
||||
if event.execution_metadata:
|
||||
for key, value in event.execution_metadata.items():
|
||||
execution_metadata_dict[key] = value
|
||||
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
# Update domain model
|
||||
domain_execution.status = NodeExecutionStatus.SUCCEEDED
|
||||
domain_execution.update_from_mapping(
|
||||
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
|
||||
)
|
||||
domain_execution.finished_at = finished_at
|
||||
domain_execution.elapsed_time = elapsed_time
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
# Update the repository with the domain model
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
|
||||
# Use the instance repository to update the workflow node execution
|
||||
self._workflow_node_execution_repository.update(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
return domain_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(
|
||||
self,
|
||||
@ -351,43 +357,52 @@ class WorkflowCycleManager:
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
) -> WorkflowNodeExecution:
|
||||
) -> NodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
|
||||
# Get the domain model from repository
|
||||
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
|
||||
if not domain_execution:
|
||||
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
|
||||
|
||||
# Process data
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
|
||||
# Convert metadata keys to strings
|
||||
execution_metadata_dict = {}
|
||||
if event.execution_metadata:
|
||||
for key, value in event.execution_metadata.items():
|
||||
execution_metadata_dict[key] = value
|
||||
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
workflow_node_execution.status = (
|
||||
WorkflowNodeExecutionStatus.FAILED.value
|
||||
|
||||
# Update domain model
|
||||
domain_execution.status = (
|
||||
NodeExecutionStatus.FAILED
|
||||
if not isinstance(event, QueueNodeExceptionEvent)
|
||||
else WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
else NodeExecutionStatus.EXCEPTION
|
||||
)
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
domain_execution.error = event.error
|
||||
domain_execution.update_from_mapping(
|
||||
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
|
||||
)
|
||||
domain_execution.finished_at = finished_at
|
||||
domain_execution.elapsed_time = elapsed_time
|
||||
|
||||
self._workflow_node_execution_repository.update(workflow_node_execution)
|
||||
# Update the repository with the domain model
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
return domain_execution
|
||||
|
||||
def _handle_workflow_node_execution_retried(
|
||||
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
) -> NodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param workflow_run: workflow run
|
||||
@ -399,47 +414,47 @@ class WorkflowCycleManager:
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
|
||||
# Convert metadata keys to strings
|
||||
origin_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
merged_metadata = (
|
||||
{**jsonable_encoder(event.execution_metadata), **origin_metadata}
|
||||
if event.execution_metadata is not None
|
||||
else origin_metadata
|
||||
|
||||
# Convert execution metadata keys to strings
|
||||
execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {}
|
||||
if event.execution_metadata:
|
||||
for key, value in event.execution_metadata.items():
|
||||
execution_metadata_dict[key] = value
|
||||
|
||||
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
|
||||
|
||||
# Create a domain model
|
||||
domain_execution = NodeExecution(
|
||||
id=str(uuid4()),
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
node_execution_id=event.node_execution_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=event.node_data.title,
|
||||
status=NodeExecutionStatus.RETRY,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
elapsed_time=elapsed_time,
|
||||
error=event.error,
|
||||
index=event.node_run_index,
|
||||
)
|
||||
execution_metadata = json.dumps(merged_metadata)
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = created_at
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
# Update with mappings
|
||||
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
|
||||
|
||||
# Use the instance repository to save the workflow node execution
|
||||
self._workflow_node_execution_repository.save(workflow_node_execution)
|
||||
# Use the instance repository to save the domain model
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
return domain_execution
|
||||
|
||||
def _workflow_start_to_stream_response(
|
||||
self,
|
||||
@ -469,7 +484,7 @@ class WorkflowCycleManager:
|
||||
workflow_run: WorkflowRun,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
|
||||
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
account = session.scalar(stmt)
|
||||
if account:
|
||||
@ -478,7 +493,7 @@ class WorkflowCycleManager:
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
}
|
||||
elif workflow_run.created_by_role == CreatedByRole.END_USER:
|
||||
elif workflow_run.created_by_role == CreatorUserRole.END_USER:
|
||||
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
||||
end_user = session.scalar(stmt)
|
||||
if end_user:
|
||||
@ -515,9 +530,9 @@ class WorkflowCycleManager:
|
||||
*,
|
||||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
workflow_node_execution: NodeExecution,
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
@ -532,7 +547,7 @@ class WorkflowCycleManager:
|
||||
title=workflow_node_execution.title,
|
||||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
inputs=workflow_node_execution.inputs,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
@ -565,9 +580,9 @@ class WorkflowCycleManager:
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
workflow_node_execution: NodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
@ -584,16 +599,16 @@ class WorkflowCycleManager:
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
process_data=workflow_node_execution.process_data_dict,
|
||||
outputs=workflow_node_execution.outputs_dict,
|
||||
inputs=workflow_node_execution.inputs,
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
execution_metadata=workflow_node_execution.metadata,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
@ -608,9 +623,9 @@ class WorkflowCycleManager:
|
||||
*,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
workflow_node_execution: NodeExecution,
|
||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
@ -627,16 +642,16 @@ class WorkflowCycleManager:
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
process_data=workflow_node_execution.process_data_dict,
|
||||
outputs=workflow_node_execution.outputs_dict,
|
||||
inputs=workflow_node_execution.inputs,
|
||||
process_data=workflow_node_execution.process_data,
|
||||
outputs=workflow_node_execution.outputs,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
execution_metadata=workflow_node_execution.metadata,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
@ -908,23 +923,6 @@ class WorkflowCycleManager:
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
# First check the cache for performance
|
||||
if node_execution_id in self._workflow_node_executions:
|
||||
cached_execution = self._workflow_node_executions[node_execution_id]
|
||||
# No need to merge with session since expire_on_commit=False
|
||||
return cached_execution
|
||||
|
||||
# If not in cache, use the instance repository to get by node_execution_id
|
||||
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
|
||||
|
||||
if not execution:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
|
||||
# Update cache
|
||||
self._workflow_node_executions[node_execution_id] = execution
|
||||
return execution
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
Handle agent log
|
||||
|
73
api/extensions/ext_request_logging.py
Normal file
73
api/extensions/ext_request_logging.py
Normal file
@ -0,0 +1,73 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import flask
|
||||
import werkzeug.http
|
||||
from flask import Flask
|
||||
from flask.signals import request_finished, request_started
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_content_type_json(content_type: str) -> bool:
|
||||
if not content_type:
|
||||
return False
|
||||
content_type_no_option, _ = werkzeug.http.parse_options_header(content_type)
|
||||
return content_type_no_option.lower() == "application/json"
|
||||
|
||||
|
||||
def _log_request_started(_sender, **_extra):
|
||||
"""Log the start of a request."""
|
||||
if not _logger.isEnabledFor(logging.DEBUG):
|
||||
return
|
||||
|
||||
request = flask.request
|
||||
if not (_is_content_type_json(request.content_type) and request.data):
|
||||
_logger.debug("Received Request %s -> %s", request.method, request.path)
|
||||
return
|
||||
try:
|
||||
json_data = json.loads(request.data)
|
||||
except (TypeError, ValueError):
|
||||
_logger.exception("Failed to parse JSON request")
|
||||
return
|
||||
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
|
||||
_logger.debug(
|
||||
"Received Request %s -> %s, Request Body:\n%s",
|
||||
request.method,
|
||||
request.path,
|
||||
formatted_json,
|
||||
)
|
||||
|
||||
|
||||
def _log_request_finished(_sender, response, **_extra):
|
||||
"""Log the end of a request."""
|
||||
if not _logger.isEnabledFor(logging.DEBUG) or response is None:
|
||||
return
|
||||
|
||||
if not _is_content_type_json(response.content_type):
|
||||
_logger.debug("Response %s %s", response.status, response.content_type)
|
||||
return
|
||||
|
||||
response_data = response.get_data(as_text=True)
|
||||
try:
|
||||
json_data = json.loads(response_data)
|
||||
except (TypeError, ValueError):
|
||||
_logger.exception("Failed to parse JSON response")
|
||||
return
|
||||
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
|
||||
_logger.debug(
|
||||
"Response %s %s, Response Body:\n%s",
|
||||
response.status,
|
||||
response.content_type,
|
||||
formatted_json,
|
||||
)
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
"""Initialize the request logging extension."""
|
||||
if not dify_config.ENABLE_REQUEST_LOGGING:
|
||||
return
|
||||
request_started.connect(_log_request_started, app)
|
||||
request_finished.connect(_log_request_finished, app)
|
@ -80,9 +80,9 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
|
||||
|
||||
|
||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["name"]])
|
||||
if not mapping.get("variable"):
|
||||
raise VariableError("missing variable")
|
||||
return mapping["variable"]
|
||||
|
||||
|
||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||
@ -123,7 +123,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
||||
result = result.model_copy(update={"selector": selector})
|
||||
return cast(Variable, result)
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
if value is None:
|
||||
return NoneSegment()
|
||||
|
@ -63,6 +63,7 @@ app_detail_fields = {
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
}
|
||||
|
||||
prompt_config_fields = {
|
||||
@ -98,6 +99,7 @@ app_partial_fields = {
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
"access_mode": fields.String,
|
||||
}
|
||||
|
||||
|
||||
@ -176,6 +178,7 @@ app_detail_fields_with_site = {
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
|
||||
"access_mode": fields.String,
|
||||
}
|
||||
|
||||
|
||||
|
@ -153,6 +153,7 @@ pipeline_import_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"pipeline_id": fields.String,
|
||||
"dataset_id": fields.String,
|
||||
"current_dsl_version": fields.String,
|
||||
"imported_dsl_version": fields.String,
|
||||
"error": fields.String,
|
||||
|
@ -42,9 +42,19 @@ conversation_variable_fields = {
|
||||
|
||||
pipeline_variable_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"value_type": fields.String(attribute="value_type.value"),
|
||||
"value": fields.Raw,
|
||||
"label": fields.String,
|
||||
"variable": fields.String,
|
||||
"type": fields.String(attribute="type.value"),
|
||||
"belong_to_node_id": fields.String,
|
||||
"max_length": fields.Integer,
|
||||
"required": fields.Boolean,
|
||||
"default_value": fields.Raw,
|
||||
"options": fields.List(fields.String),
|
||||
"placeholder": fields.String,
|
||||
"tooltips": fields.String,
|
||||
"allowed_file_types": fields.List(fields.String),
|
||||
"allow_file_extension": fields.List(fields.String),
|
||||
"allow_file_upload_methods": fields.List(fields.String),
|
||||
}
|
||||
|
||||
workflow_fields = {
|
||||
@ -62,6 +72,7 @@ workflow_fields = {
|
||||
"tool_published": fields.Boolean,
|
||||
"environment_variables": fields.List(EnvironmentVariableField()),
|
||||
"conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
|
||||
"rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)),
|
||||
}
|
||||
|
||||
workflow_partial_fields = {
|
||||
|
@ -0,0 +1,51 @@
|
||||
"""add WorkflowDraftVariable model
|
||||
|
||||
Revision ID: 2adcbe1f5dfb
|
||||
Revises: d28f2004b072
|
||||
Create Date: 2025-05-15 15:31:03.128680
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2adcbe1f5dfb"
|
||||
down_revision = "d28f2004b072"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"workflow_draft_variables",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("last_edited_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("node_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.String(length=255), nullable=False),
|
||||
sa.Column("selector", sa.String(length=255), nullable=False),
|
||||
sa.Column("value_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("value", sa.Text(), nullable=False),
|
||||
sa.Column("visible", sa.Boolean(), nullable=False),
|
||||
sa.Column("editable", sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")),
|
||||
sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")),
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# Dropping `workflow_draft_variables` also drops any index associated with it.
|
||||
op.drop_table("workflow_draft_variables")
|
||||
|
||||
# ### end Alembic commands ###
|
@ -27,7 +27,7 @@ from .dataset import (
|
||||
Whitelist,
|
||||
)
|
||||
from .engine import db
|
||||
from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
|
||||
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
|
||||
from .model import (
|
||||
ApiRequest,
|
||||
ApiToken,
|
||||
@ -112,7 +112,7 @@ __all__ = [
|
||||
"CeleryTaskSet",
|
||||
"Conversation",
|
||||
"ConversationVariable",
|
||||
"CreatedByRole",
|
||||
"CreatorUserRole",
|
||||
"DataSourceApiKeyAuthBinding",
|
||||
"DataSourceOauthBinding",
|
||||
"Dataset",
|
||||
|
@ -1166,6 +1166,9 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def pipeline(self):
|
||||
return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
|
||||
|
||||
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "pipeline_customized_templates"
|
||||
@ -1195,7 +1198,6 @@ class Pipeline(Base): # type: ignore[name-defined]
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
|
||||
mode = db.Column(db.String(255), nullable=False)
|
||||
workflow_id = db.Column(StringUUID, nullable=True)
|
||||
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
@ -1203,3 +1205,6 @@ class Pipeline(Base): # type: ignore[name-defined]
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
|
||||
|
@ -1,7 +1,7 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class CreatedByRole(StrEnum):
|
||||
class CreatorUserRole(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@ -14,3 +14,10 @@ class UserFrom(StrEnum):
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
APP_RUN = "app-run"
|
||||
|
||||
|
||||
class DraftVariableType(StrEnum):
|
||||
# node means that the correspond variable
|
||||
NODE = "node"
|
||||
SYS = "sys"
|
||||
CONVERSATION = "conversation"
|
||||
|
@ -29,7 +29,7 @@ from libs.helper import generate_string
|
||||
from .account import Account, Tenant
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .enums import CreatedByRole
|
||||
from .enums import CreatorUserRole
|
||||
from .types import StringUUID
|
||||
from .workflow import WorkflowRunStatus
|
||||
|
||||
@ -1270,7 +1270,7 @@ class MessageFile(Base):
|
||||
url: str | None = None,
|
||||
belongs_to: Literal["user", "assistant"] | None = None,
|
||||
upload_file_id: str | None = None,
|
||||
created_by_role: CreatedByRole,
|
||||
created_by_role: CreatorUserRole,
|
||||
created_by: str,
|
||||
):
|
||||
self.message_id = message_id
|
||||
@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin):
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
app_id = db.Column(StringUUID, nullable=True)
|
||||
type = db.Column(db.String(255), nullable=False)
|
||||
external_user_id = db.Column(db.String(255), nullable=True)
|
||||
@ -1547,7 +1547,7 @@ class UploadFile(Base):
|
||||
size: int,
|
||||
extension: str,
|
||||
mime_type: str,
|
||||
created_by_role: CreatedByRole,
|
||||
created_by_role: CreatorUserRole,
|
||||
created_by: str,
|
||||
created_at: datetime,
|
||||
used: bool,
|
||||
|
@ -1,4 +1,7 @@
|
||||
from sqlalchemy import CHAR, TypeDecorator
|
||||
import enum
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
|
||||
@ -24,3 +27,51 @@ class StringUUID(TypeDecorator):
|
||||
if value is None:
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
_E = TypeVar("_E", bound=enum.StrEnum)
|
||||
|
||||
|
||||
class EnumText(TypeDecorator, Generic[_E]):
|
||||
impl = VARCHAR
|
||||
cache_ok = True
|
||||
|
||||
_length: int
|
||||
_enum_class: type[_E]
|
||||
|
||||
def __init__(self, enum_class: type[_E], length: int | None = None):
|
||||
self._enum_class = enum_class
|
||||
max_enum_value_len = max(len(e.value) for e in enum_class)
|
||||
if length is not None:
|
||||
if length < max_enum_value_len:
|
||||
raise ValueError("length should be greater than enum value length.")
|
||||
self._length = length
|
||||
else:
|
||||
# leave some rooms for future longer enum values.
|
||||
self._length = max(max_enum_value_len, 20)
|
||||
|
||||
def process_bind_param(self, value: _E | str | None, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, self._enum_class):
|
||||
return value.value
|
||||
elif isinstance(value, str):
|
||||
self._enum_class(value)
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(VARCHAR(self._length))
|
||||
|
||||
def process_result_value(self, value, dialect) -> _E | None:
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(f"expected str, got {type(value)}")
|
||||
return self._enum_class(value)
|
||||
|
||||
def compare_values(self, x, y):
|
||||
if x is None or y is None:
|
||||
return x is y
|
||||
return x == y
|
||||
|
@ -1,29 +1,36 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Self, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Self, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from core.variables import utils as variable_utils
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import AppMode
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
import contexts
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, Variable
|
||||
from core.variables import SecretVariable, Segment, SegmentType, Variable
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
|
||||
from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .enums import CreatedByRole
|
||||
from .types import StringUUID
|
||||
from .enums import CreatorUserRole, DraftVariableType
|
||||
from .types import EnumText, StringUUID
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import AppMode
|
||||
@ -331,6 +338,7 @@ class Workflow(Base):
|
||||
"features": self.features_dict,
|
||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
|
||||
"rag_pipeline_variables": [var.model_dump(mode="json") for var in self.rag_pipeline_variables],
|
||||
}
|
||||
return result
|
||||
|
||||
@ -352,21 +360,19 @@ class Workflow(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
|
||||
def rag_pipeline_variables(self) -> Sequence[Variable]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = {}
|
||||
for k, v in variables_dict.items():
|
||||
results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
|
||||
results = [v for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@pipeline_variables.setter
|
||||
def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: List[dict]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
|
||||
{item["variable"]: item for item in values},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@ -452,15 +458,15 @@ class WorkflowRun(Base):
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from models.model import EndUser
|
||||
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
@property
|
||||
def graph_dict(self):
|
||||
@ -657,24 +663,24 @@ class WorkflowNodeExecution(Base):
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from models.model import EndUser
|
||||
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
@property
|
||||
def inputs_dict(self):
|
||||
return json.loads(self.inputs) if self.inputs else None
|
||||
|
||||
@property
|
||||
def outputs_dict(self):
|
||||
def outputs_dict(self) -> dict[str, Any] | None:
|
||||
return json.loads(self.outputs) if self.outputs else None
|
||||
|
||||
@property
|
||||
@ -682,7 +688,7 @@ class WorkflowNodeExecution(Base):
|
||||
return json.loads(self.process_data) if self.process_data else None
|
||||
|
||||
@property
|
||||
def execution_metadata_dict(self):
|
||||
def execution_metadata_dict(self) -> dict[str, Any] | None:
|
||||
return json.loads(self.execution_metadata) if self.execution_metadata else None
|
||||
|
||||
@property
|
||||
@ -778,15 +784,15 @@ class WorkflowAppLog(Base):
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from models.model import EndUser
|
||||
|
||||
created_by_role = CreatedByRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
|
||||
class ConversationVariable(Base):
|
||||
@ -820,3 +826,201 @@ class ConversationVariable(Base):
|
||||
def to_variable(self) -> Variable:
|
||||
mapping = json.loads(self.data)
|
||||
return variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
|
||||
|
||||
# Only `sys.query` and `sys.files` could be modified.
|
||||
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
|
||||
|
||||
|
||||
def _naive_utc_datetime():
|
||||
return datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
|
||||
class WorkflowDraftVariable(Base):
|
||||
@staticmethod
|
||||
def unique_columns() -> list[str]:
|
||||
return [
|
||||
"app_id",
|
||||
"node_id",
|
||||
"name",
|
||||
]
|
||||
|
||||
__tablename__ = "workflow_draft_variables"
|
||||
__table_args__ = (UniqueConstraint(*unique_columns()),)
|
||||
|
||||
# id is the unique identifier of a draft variable.
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
|
||||
created_at = mapped_column(
|
||||
db.DateTime,
|
||||
nullable=False,
|
||||
default=_naive_utc_datetime,
|
||||
server_default=func.current_timestamp(),
|
||||
)
|
||||
|
||||
updated_at = mapped_column(
|
||||
db.DateTime,
|
||||
nullable=False,
|
||||
default=_naive_utc_datetime,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
)
|
||||
|
||||
# "`app_id` maps to the `id` field in the `model.App` model."
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
# `last_edited_at` records when the value of a given draft variable
|
||||
# is edited.
|
||||
#
|
||||
# If it's not edited after creation, its value is `None`.
|
||||
last_edited_at: Mapped[datetime | None] = mapped_column(
|
||||
db.DateTime,
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
|
||||
# The `node_id` field is special.
|
||||
#
|
||||
# If the variable is a conversation variable or a system variable, then the value of `node_id`
|
||||
# is `conversation` or `sys`, respective.
|
||||
#
|
||||
# Otherwise, if the variable is a variable belonging to a specific node, the value of `_node_id` is
|
||||
# the identity of correspond node in graph definition. An example of node id is `"1745769620734"`.
|
||||
#
|
||||
# However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other
|
||||
# "Answer" node conform the rules above.)
|
||||
node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id")
|
||||
|
||||
# From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than
|
||||
# 80 chars.
|
||||
#
|
||||
# ref: api/core/workflow/entities/variable_pool.py:18
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(
|
||||
sa.String(255),
|
||||
default="",
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")
|
||||
|
||||
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
|
||||
# JSON string
|
||||
value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")
|
||||
|
||||
# visible
|
||||
visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
def get_selector(self) -> list[str]:
|
||||
selector = json.loads(self.selector)
|
||||
if not isinstance(selector, list):
|
||||
_logger.error(
|
||||
"invalid selector loaded from database, type=%s, value=%s",
|
||||
type(selector),
|
||||
self.selector,
|
||||
)
|
||||
raise ValueError("invalid selector.")
|
||||
return selector
|
||||
|
||||
def _set_selector(self, value: list[str]):
|
||||
self.selector = json.dumps(value)
|
||||
|
||||
def get_value(self) -> Segment | None:
|
||||
return build_segment(json.loads(self.value))
|
||||
|
||||
def set_name(self, name: str):
|
||||
self.name = name
|
||||
self._set_selector([self.node_id, name])
|
||||
|
||||
def set_value(self, value: Segment):
|
||||
self.value = json.dumps(value.value)
|
||||
self.value_type = value.value_type
|
||||
|
||||
def get_node_id(self) -> str | None:
|
||||
if self.get_variable_type() == DraftVariableType.NODE:
|
||||
return self.node_id
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_variable_type(self) -> DraftVariableType:
|
||||
match self.node_id:
|
||||
case DraftVariableType.CONVERSATION:
|
||||
return DraftVariableType.CONVERSATION
|
||||
case DraftVariableType.SYS:
|
||||
return DraftVariableType.SYS
|
||||
case _:
|
||||
return DraftVariableType.NODE
|
||||
|
||||
@classmethod
|
||||
def _new(
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
description: str = "",
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = WorkflowDraftVariable()
|
||||
variable.created_at = _naive_utc_datetime()
|
||||
variable.updated_at = _naive_utc_datetime()
|
||||
variable.description = description
|
||||
variable.app_id = app_id
|
||||
variable.node_id = node_id
|
||||
variable.name = name
|
||||
variable.set_value(value)
|
||||
variable._set_selector(list(variable_utils.to_selector(node_id, name)))
|
||||
return variable
|
||||
|
||||
@classmethod
|
||||
def new_conversation_variable(
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
node_id=CONVERSATION_VARIABLE_NODE_ID,
|
||||
name=name,
|
||||
value=value,
|
||||
)
|
||||
return variable
|
||||
|
||||
@classmethod
|
||||
def new_sys_variable(
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
editable: bool = False,
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value)
|
||||
variable.editable = editable
|
||||
return variable
|
||||
|
||||
@classmethod
|
||||
def new_node_variable(
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
visible: bool = True,
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
|
||||
variable.visible = visible
|
||||
variable.editable = True
|
||||
return variable
|
||||
|
||||
@property
|
||||
def edited(self):
|
||||
return self.last_edited_at is not None
|
||||
|
||||
|
||||
def is_system_variable_editable(name: str) -> bool:
|
||||
return name in _EDITABLE_SYSTEM_VARIABLE
|
||||
|
@ -72,7 +72,7 @@ dependencies = [
|
||||
"python-dotenv==1.0.1",
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=6.0.0",
|
||||
"redis[hiredis]~=6.1.0",
|
||||
"resend~=2.9.0",
|
||||
"sentry-sdk[flask]~=2.28.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
|
@ -49,7 +49,7 @@ from services.errors.account import (
|
||||
RoleAlreadyAssignedError,
|
||||
TenantNotFoundError,
|
||||
)
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.delete_account_task import delete_account_task
|
||||
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
|
||||
@ -628,6 +628,10 @@ class TenantService:
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
|
||||
raise WorkSpaceNotAllowedCreateError()
|
||||
|
||||
workspaces = FeatureService.get_system_features().license.workspaces
|
||||
if not workspaces.is_available():
|
||||
raise WorkspacesLimitExceededError()
|
||||
|
||||
if name:
|
||||
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
|
||||
else:
|
||||
@ -937,7 +941,11 @@ class RegisterService:
|
||||
if open_id is not None and provider is not None:
|
||||
AccountService.link_account_integrate(provider, open_id, account)
|
||||
|
||||
if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
|
||||
if (
|
||||
FeatureService.get_system_features().is_allow_create_workspace
|
||||
and create_workspace_required
|
||||
and FeatureService.get_system_features().license.workspaces.is_available()
|
||||
):
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
|
@ -40,7 +40,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
|
||||
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
|
||||
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.2.0"
|
||||
CURRENT_DSL_VERSION = "0.3.0"
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
|
@ -18,8 +18,10 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.model import App, AppMode, AppModelConfig, Site
|
||||
from models.tools import ApiToolProvider
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.tag_service import TagService
|
||||
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
|
||||
|
||||
@ -155,6 +157,10 @@ class AppService:
|
||||
|
||||
app_was_created.send(app, account=account)
|
||||
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
# update web app setting as private
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private")
|
||||
|
||||
return app
|
||||
|
||||
def get_app(self, app: App) -> App:
|
||||
@ -307,6 +313,10 @@ class AppService:
|
||||
db.session.delete(app)
|
||||
db.session.commit()
|
||||
|
||||
# clean up web app settings
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
|
||||
|
||||
# Trigger asynchronous deletion of app and related data
|
||||
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
|
||||
|
||||
@ -373,3 +383,15 @@ class AppService:
|
||||
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def get_app_code_by_id(app_id: str) -> str:
|
||||
"""
|
||||
Get app code by app id
|
||||
:param app_id: app id
|
||||
:return: app code
|
||||
"""
|
||||
site = db.session.query(Site).filter(Site.app_id == app_id).first()
|
||||
if not site:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
return str(site.code)
|
||||
|
@ -40,6 +40,7 @@ from models.dataset import (
|
||||
Document,
|
||||
DocumentSegment,
|
||||
ExternalKnowledgeBindings,
|
||||
Pipeline,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.source import DataSourceOauthBinding
|
||||
@ -244,11 +245,24 @@ class DatasetService:
|
||||
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first():
|
||||
if (
|
||||
db.session.query(Dataset)
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
):
|
||||
raise DatasetNameDuplicateError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
tenant_id=tenant_id,
|
||||
name=rag_pipeline_dataset_create_entity.name,
|
||||
description=rag_pipeline_dataset_create_entity.description,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(pipeline)
|
||||
db.session.flush()
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=rag_pipeline_dataset_create_entity.name,
|
||||
@ -257,7 +271,8 @@ class DatasetService:
|
||||
provider="vendor",
|
||||
runtime_mode="rag_pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
||||
created_by=current_user.id
|
||||
created_by=current_user.id,
|
||||
pipeline_id=pipeline.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
@ -269,7 +284,11 @@ class DatasetService:
|
||||
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if db.session.query(Dataset).filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first():
|
||||
if (
|
||||
db.session.query(Dataset)
|
||||
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
|
||||
.first()
|
||||
):
|
||||
raise DatasetNameDuplicateError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
@ -282,10 +301,13 @@ class DatasetService:
|
||||
runtime_mode="rag_pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
||||
)
|
||||
|
||||
if rag_pipeline_dataset_create_entity.yaml_content:
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline(
|
||||
current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||
account=current_user,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||
dataset=dataset,
|
||||
)
|
||||
return {
|
||||
"id": rag_pipeline_import_info.id,
|
||||
@ -1053,7 +1075,7 @@ class DocumentService:
|
||||
created_by=account.id,
|
||||
)
|
||||
else:
|
||||
logging.warn(
|
||||
logging.warning(
|
||||
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
|
||||
)
|
||||
return
|
||||
@ -1240,281 +1262,282 @@ class DocumentService:
|
||||
|
||||
return documents, batch
|
||||
|
||||
@staticmethod
|
||||
def save_document_with_dataset_id(
|
||||
dataset: Dataset,
|
||||
knowledge_config: KnowledgeConfig,
|
||||
account: Account | Any,
|
||||
dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = "web",
|
||||
):
|
||||
# check document limit
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
# @staticmethod
|
||||
# def save_document_with_dataset_id(
|
||||
# dataset: Dataset,
|
||||
# knowledge_config: KnowledgeConfig,
|
||||
# account: Account | Any,
|
||||
# dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
# created_from: str = "web",
|
||||
# ):
|
||||
# # check document limit
|
||||
# features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
if not knowledge_config.original_document_id:
|
||||
count = 0
|
||||
if knowledge_config.data_source:
|
||||
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
|
||||
count = len(upload_file_list)
|
||||
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
||||
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
|
||||
for notion_info in notion_info_list: # type: ignore
|
||||
count = count + len(notion_info.pages)
|
||||
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
|
||||
website_info = knowledge_config.data_source.info_list.website_info_list
|
||||
count = len(website_info.urls) # type: ignore
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
# if features.billing.enabled:
|
||||
# if not knowledge_config.original_document_id:
|
||||
# count = 0
|
||||
# if knowledge_config.data_source:
|
||||
# if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||
# upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||
# # type: ignore
|
||||
# count = len(upload_file_list)
|
||||
# elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
||||
# notion_info_list = knowledge_config.data_source.info_list.notion_info_list
|
||||
# for notion_info in notion_info_list: # type: ignore
|
||||
# count = count + len(notion_info.pages)
|
||||
# elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
|
||||
# website_info = knowledge_config.data_source.info_list.website_info_list
|
||||
# count = len(website_info.urls) # type: ignore
|
||||
# batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
|
||||
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
# if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||
# raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
# if count > batch_upload_limit:
|
||||
# raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
|
||||
DocumentService.check_documents_upload_quota(count, features)
|
||||
# DocumentService.check_documents_upload_quota(count, features)
|
||||
|
||||
# if dataset is empty, update dataset data_source_type
|
||||
if not dataset.data_source_type:
|
||||
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
|
||||
# # if dataset is empty, update dataset data_source_type
|
||||
# if not dataset.data_source_type:
|
||||
# dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
|
||||
|
||||
if not dataset.indexing_technique:
|
||||
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||
raise ValueError("Indexing technique is invalid")
|
||||
# if not dataset.indexing_technique:
|
||||
# if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||
# raise ValueError("Indexing technique is invalid")
|
||||
|
||||
dataset.indexing_technique = knowledge_config.indexing_technique
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
|
||||
dataset_embedding_model = knowledge_config.embedding_model
|
||||
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
|
||||
else:
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
dataset_embedding_model = embedding_model.model
|
||||
dataset_embedding_model_provider = embedding_model.provider
|
||||
dataset.embedding_model = dataset_embedding_model
|
||||
dataset.embedding_model_provider = dataset_embedding_model_provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
dataset_embedding_model_provider, dataset_embedding_model
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
# dataset.indexing_technique = knowledge_config.indexing_technique
|
||||
# if knowledge_config.indexing_technique == "high_quality":
|
||||
# model_manager = ModelManager()
|
||||
# if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
|
||||
# dataset_embedding_model = knowledge_config.embedding_model
|
||||
# dataset_embedding_model_provider = knowledge_config.embedding_model_provider
|
||||
# else:
|
||||
# embedding_model = model_manager.get_default_model_instance(
|
||||
# tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
# )
|
||||
# dataset_embedding_model = embedding_model.model
|
||||
# dataset_embedding_model_provider = embedding_model.provider
|
||||
# dataset.embedding_model = dataset_embedding_model
|
||||
# dataset.embedding_model_provider = dataset_embedding_model_provider
|
||||
# dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
# dataset_embedding_model_provider, dataset_embedding_model
|
||||
# )
|
||||
# dataset.collection_binding_id = dataset_collection_binding.id
|
||||
# if not dataset.retrieval_model:
|
||||
# default_retrieval_model = {
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
# "reranking_enable": False,
|
||||
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
# "top_k": 2,
|
||||
# "score_threshold_enabled": False,
|
||||
# }
|
||||
|
||||
dataset.retrieval_model = (
|
||||
knowledge_config.retrieval_model.model_dump()
|
||||
if knowledge_config.retrieval_model
|
||||
else default_retrieval_model
|
||||
) # type: ignore
|
||||
# dataset.retrieval_model = (
|
||||
# knowledge_config.retrieval_model.model_dump()
|
||||
# if knowledge_config.retrieval_model
|
||||
# else default_retrieval_model
|
||||
# ) # type: ignore
|
||||
|
||||
documents = []
|
||||
if knowledge_config.original_document_id:
|
||||
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
|
||||
documents.append(document)
|
||||
batch = document.batch
|
||||
else:
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||
# save process rule
|
||||
if not dataset_process_rule:
|
||||
process_rule = knowledge_config.process_rule
|
||||
if process_rule:
|
||||
if process_rule.mode in ("custom", "hierarchical"):
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
elif process_rule.mode == "automatic":
|
||||
dataset_process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset.id,
|
||||
mode=process_rule.mode,
|
||||
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||
created_by=account.id,
|
||||
)
|
||||
else:
|
||||
logging.warn(
|
||||
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
|
||||
)
|
||||
return
|
||||
db.session.add(dataset_process_rule)
|
||||
db.session.commit()
|
||||
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
position = DocumentService.get_documents_position(dataset.id)
|
||||
document_ids = []
|
||||
duplicate_document_ids = []
|
||||
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
|
||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
|
||||
for file_id in upload_file_list:
|
||||
file = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
# documents = []
|
||||
# if knowledge_config.original_document_id:
|
||||
# document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
|
||||
# documents.append(document)
|
||||
# batch = document.batch
|
||||
# else:
|
||||
# batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||
# # save process rule
|
||||
# if not dataset_process_rule:
|
||||
# process_rule = knowledge_config.process_rule
|
||||
# if process_rule:
|
||||
# if process_rule.mode in ("custom", "hierarchical"):
|
||||
# dataset_process_rule = DatasetProcessRule(
|
||||
# dataset_id=dataset.id,
|
||||
# mode=process_rule.mode,
|
||||
# rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||
# created_by=account.id,
|
||||
# )
|
||||
# elif process_rule.mode == "automatic":
|
||||
# dataset_process_rule = DatasetProcessRule(
|
||||
# dataset_id=dataset.id,
|
||||
# mode=process_rule.mode,
|
||||
# rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||
# created_by=account.id,
|
||||
# )
|
||||
# else:
|
||||
# logging.warn(
|
||||
# f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
|
||||
# )
|
||||
# return
|
||||
# db.session.add(dataset_process_rule)
|
||||
# db.session.commit()
|
||||
# lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
|
||||
# with redis_client.lock(lock_name, timeout=600):
|
||||
# position = DocumentService.get_documents_position(dataset.id)
|
||||
# document_ids = []
|
||||
# duplicate_document_ids = []
|
||||
# if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
|
||||
# upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
|
||||
# for file_id in upload_file_list:
|
||||
# file = (
|
||||
# db.session.query(UploadFile)
|
||||
# .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
|
||||
# .first()
|
||||
# )
|
||||
|
||||
# raise error if file not found
|
||||
if not file:
|
||||
raise FileNotExistsError()
|
||||
# # raise error if file not found
|
||||
# if not file:
|
||||
# raise FileNotExistsError()
|
||||
|
||||
file_name = file.name
|
||||
data_source_info = {
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
# check duplicate
|
||||
if knowledge_config.duplicate:
|
||||
document = Document.query.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="upload_file",
|
||||
enabled=True,
|
||||
name=file_name,
|
||||
).first()
|
||||
if document:
|
||||
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
|
||||
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.created_from = created_from
|
||||
document.doc_form = knowledge_config.doc_form
|
||||
document.doc_language = knowledge_config.doc_language
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.batch = batch
|
||||
document.indexing_status = "waiting"
|
||||
db.session.add(document)
|
||||
documents.append(document)
|
||||
duplicate_document_ids.append(document.id)
|
||||
continue
|
||||
document = DocumentService.build_document(
|
||||
dataset,
|
||||
dataset_process_rule.id, # type: ignore
|
||||
knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
||||
knowledge_config.doc_form,
|
||||
knowledge_config.doc_language,
|
||||
data_source_info,
|
||||
created_from,
|
||||
position,
|
||||
account,
|
||||
file_name,
|
||||
batch,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
|
||||
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
|
||||
if not notion_info_list:
|
||||
raise ValueError("No notion info list found.")
|
||||
exist_page_ids = []
|
||||
exist_document = {}
|
||||
documents = Document.query.filter_by(
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
).all()
|
||||
if documents:
|
||||
for document in documents:
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
exist_document[data_source_info["notion_page_id"]] = document.id
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info.workspace_id
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
raise ValueError("Data source binding not found.")
|
||||
for page in notion_info.pages:
|
||||
if page.page_id not in exist_page_ids:
|
||||
data_source_info = {
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_page_id": page.page_id,
|
||||
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
|
||||
"type": page.type,
|
||||
}
|
||||
# Truncate page name to 255 characters to prevent DB field length errors
|
||||
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
|
||||
document = DocumentService.build_document(
|
||||
dataset,
|
||||
dataset_process_rule.id, # type: ignore
|
||||
knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
||||
knowledge_config.doc_form,
|
||||
knowledge_config.doc_language,
|
||||
data_source_info,
|
||||
created_from,
|
||||
position,
|
||||
account,
|
||||
truncated_page_name,
|
||||
batch,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
else:
|
||||
exist_document.pop(page.page_id)
|
||||
# delete not selected documents
|
||||
if len(exist_document) > 0:
|
||||
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
|
||||
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
|
||||
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
|
||||
if not website_info:
|
||||
raise ValueError("No website info list found.")
|
||||
urls = website_info.urls
|
||||
for url in urls:
|
||||
data_source_info = {
|
||||
"url": url,
|
||||
"provider": website_info.provider,
|
||||
"job_id": website_info.job_id,
|
||||
"only_main_content": website_info.only_main_content,
|
||||
"mode": "crawl",
|
||||
}
|
||||
if len(url) > 255:
|
||||
document_name = url[:200] + "..."
|
||||
else:
|
||||
document_name = url
|
||||
document = DocumentService.build_document(
|
||||
dataset,
|
||||
dataset_process_rule.id, # type: ignore
|
||||
knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
||||
knowledge_config.doc_form,
|
||||
knowledge_config.doc_language,
|
||||
data_source_info,
|
||||
created_from,
|
||||
position,
|
||||
account,
|
||||
document_name,
|
||||
batch,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
db.session.commit()
|
||||
# file_name = file.name
|
||||
# data_source_info = {
|
||||
# "upload_file_id": file_id,
|
||||
# }
|
||||
# # check duplicate
|
||||
# if knowledge_config.duplicate:
|
||||
# document = Document.query.filter_by(
|
||||
# dataset_id=dataset.id,
|
||||
# tenant_id=current_user.current_tenant_id,
|
||||
# data_source_type="upload_file",
|
||||
# enabled=True,
|
||||
# name=file_name,
|
||||
# ).first()
|
||||
# if document:
|
||||
# document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
|
||||
# document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
# document.created_from = created_from
|
||||
# document.doc_form = knowledge_config.doc_form
|
||||
# document.doc_language = knowledge_config.doc_language
|
||||
# document.data_source_info = json.dumps(data_source_info)
|
||||
# document.batch = batch
|
||||
# document.indexing_status = "waiting"
|
||||
# db.session.add(document)
|
||||
# documents.append(document)
|
||||
# duplicate_document_ids.append(document.id)
|
||||
# continue
|
||||
# document = DocumentService.build_document(
|
||||
# dataset,
|
||||
# dataset_process_rule.id, # type: ignore
|
||||
# knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
||||
# knowledge_config.doc_form,
|
||||
# knowledge_config.doc_language,
|
||||
# data_source_info,
|
||||
# created_from,
|
||||
# position,
|
||||
# account,
|
||||
# file_name,
|
||||
# batch,
|
||||
# )
|
||||
# db.session.add(document)
|
||||
# db.session.flush()
|
||||
# document_ids.append(document.id)
|
||||
# documents.append(document)
|
||||
# position += 1
|
||||
# elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
|
||||
# notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
|
||||
# if not notion_info_list:
|
||||
# raise ValueError("No notion info list found.")
|
||||
# exist_page_ids = []
|
||||
# exist_document = {}
|
||||
# documents = Document.query.filter_by(
|
||||
# dataset_id=dataset.id,
|
||||
# tenant_id=current_user.current_tenant_id,
|
||||
# data_source_type="notion_import",
|
||||
# enabled=True,
|
||||
# ).all()
|
||||
# if documents:
|
||||
# for document in documents:
|
||||
# data_source_info = json.loads(document.data_source_info)
|
||||
# exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# exist_document[data_source_info["notion_page_id"]] = document.id
|
||||
# for notion_info in notion_info_list:
|
||||
# workspace_id = notion_info.workspace_id
|
||||
# data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
# db.and_(
|
||||
# DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
# DataSourceOauthBinding.provider == "notion",
|
||||
# DataSourceOauthBinding.disabled == False,
|
||||
# DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
# )
|
||||
# ).first()
|
||||
# if not data_source_binding:
|
||||
# raise ValueError("Data source binding not found.")
|
||||
# for page in notion_info.pages:
|
||||
# if page.page_id not in exist_page_ids:
|
||||
# data_source_info = {
|
||||
# "notion_workspace_id": workspace_id,
|
||||
# "notion_page_id": page.page_id,
|
||||
# "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
|
||||
# "type": page.type,
|
||||
# }
|
||||
# # Truncate page name to 255 characters to prevent DB field length errors
|
||||
# truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
|
||||
# document = DocumentService.build_document(
|
||||
# dataset,
|
||||
# dataset_process_rule.id, # type: ignore
|
||||
# knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
||||
# knowledge_config.doc_form,
|
||||
# knowledge_config.doc_language,
|
||||
# data_source_info,
|
||||
# created_from,
|
||||
# position,
|
||||
# account,
|
||||
# truncated_page_name,
|
||||
# batch,
|
||||
# )
|
||||
# db.session.add(document)
|
||||
# db.session.flush()
|
||||
# document_ids.append(document.id)
|
||||
# documents.append(document)
|
||||
# position += 1
|
||||
# else:
|
||||
# exist_document.pop(page.page_id)
|
||||
# # delete not selected documents
|
||||
# if len(exist_document) > 0:
|
||||
# clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
|
||||
# elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
|
||||
# website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
|
||||
# if not website_info:
|
||||
# raise ValueError("No website info list found.")
|
||||
# urls = website_info.urls
|
||||
# for url in urls:
|
||||
# data_source_info = {
|
||||
# "url": url,
|
||||
# "provider": website_info.provider,
|
||||
# "job_id": website_info.job_id,
|
||||
# "only_main_content": website_info.only_main_content,
|
||||
# "mode": "crawl",
|
||||
# }
|
||||
# if len(url) > 255:
|
||||
# document_name = url[:200] + "..."
|
||||
# else:
|
||||
# document_name = url
|
||||
# document = DocumentService.build_document(
|
||||
# dataset,
|
||||
# dataset_process_rule.id, # type: ignore
|
||||
# knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
||||
# knowledge_config.doc_form,
|
||||
# knowledge_config.doc_language,
|
||||
# data_source_info,
|
||||
# created_from,
|
||||
# position,
|
||||
# account,
|
||||
# document_name,
|
||||
# batch,
|
||||
# )
|
||||
# db.session.add(document)
|
||||
# db.session.flush()
|
||||
# document_ids.append(document.id)
|
||||
# documents.append(document)
|
||||
# position += 1
|
||||
# db.session.commit()
|
||||
|
||||
# trigger async task
|
||||
if document_ids:
|
||||
document_indexing_task.delay(dataset.id, document_ids)
|
||||
if duplicate_document_ids:
|
||||
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
|
||||
# # trigger async task
|
||||
# if document_ids:
|
||||
# document_indexing_task.delay(dataset.id, document_ids)
|
||||
# if duplicate_document_ids:
|
||||
# duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
|
||||
|
||||
return documents, batch
|
||||
# return documents, batch
|
||||
|
||||
@staticmethod
|
||||
def check_documents_upload_quota(count: int, features: FeatureModel):
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user