Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-04 17:50:41 +08:00)

Commit 98e44ca2cc: merge main
.github/workflows/api-tests.yml (12 changes, vendored)

@@ -45,7 +45,17 @@ jobs:
         run: uv sync --project api --dev

       - name: Run Unit tests
-        run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
+        run: |
+          uv run --project api bash dev/pytest/pytest_unit_tests.sh
+          # Extract coverage percentage and create a summary
+          TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
+
+          # Create a detailed coverage summary
+          echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
+          echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
+          echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
+          uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
+          echo "\`\`\`" >> $GITHUB_STEP_SUMMARY

       - name: Run dify config tests
         run: uv run --project api dev/pytest/pytest_config_tests.py
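Note: the new coverage step assumes a coverage.json report already exists in the job workspace (with coverage.py that file comes from running "coverage json", or from a pytest-cov run followed by that command). A minimal sketch of what the inline python -c extraction does, under that assumption:

    import json

    # "coverage.json" is the default output path of the `coverage json` command
    with open("coverage.json") as f:
        report = json.load(f)

    # coverage.py's JSON report exposes a pre-formatted percentage string here
    total = report["totals"]["percent_covered_display"]
    print(f"Total Coverage: {total}%")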
.gitignore (1 change, vendored)

@@ -46,6 +46,7 @@ htmlcov/
 .cache
 nosetests.xml
 coverage.xml
+coverage.json
 *.cover
 *.py,cover
 .hypothesis/
@@ -165,6 +165,7 @@ MILVUS_URI=http://127.0.0.1:19530
 MILVUS_TOKEN=
 MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
+MILVUS_ANALYZER_PARAMS=

 # MyScale configuration
 MYSCALE_HOST=127.0.0.1

@@ -423,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5
 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800

+# Workflow storage configuration
+# Options: rdbms, hybrid
+# rdbms: Use only the relational database (default)
+# hybrid: Save new data to object storage, read from both object storage and RDBMS
+WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
+
 # App configuration
 APP_MAX_EXECUTION_TIME=1200
 APP_MAX_ACTIVE_REQUESTS=0
@@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp):
         ext_otel,
         ext_proxy_fix,
         ext_redis,
+        ext_repositories,
         ext_sentry,
         ext_set_secretkey,
         ext_storage,

@@ -74,6 +75,7 @@ def initialize_extensions(app: DifyApp):
         ext_migrate,
         ext_redis,
         ext_storage,
+        ext_repositories,
         ext_celery,
         ext_login,
         ext_mail,
@@ -12,7 +12,7 @@ from pydantic import (
 )
 from pydantic_settings import BaseSettings

-from configs.feature.hosted_service import HostedServiceConfig
+from .hosted_service import HostedServiceConfig


 class SecurityConfig(BaseSettings):

@@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings):
         default=100,
     )

+    WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
+        default="rdbms",
+        description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
+    )
+

 class AuthConfig(BaseSettings):
     """
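Note: a minimal sketch of how a caller might branch on the new WORKFLOW_NODE_EXECUTION_STORAGE setting; this helper is illustrative only and is not the RepositoryFactory wiring that appears later in this commit.

    from configs import dify_config

    def describe_node_execution_backend() -> str:  # hypothetical helper for illustration
        if dify_config.WORKFLOW_NODE_EXECUTION_STORAGE == "hybrid":
            # hybrid: save new data to object storage, read from both object storage and the RDBMS
            return "object storage + RDBMS"
        # "rdbms" (the default): use only the relational database
        return "RDBMS only"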
@@ -39,3 +39,8 @@ class MilvusConfig(BaseSettings):
         "older versions",
         default=True,
     )
+
+    MILVUS_ANALYZER_PARAMS: Optional[str] = Field(
+        description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
+        default=None,
+    )
@@ -4,14 +4,10 @@ import platform
 import re
 import urllib.parse
 import warnings
-from collections.abc import Mapping
-from typing import Any
 from uuid import uuid4

 import httpx

-from constants import DEFAULT_FILE_NUMBER_LIMITS
-
 try:
     import magic
 except ImportError:

@@ -31,8 +27,6 @@ except ImportError:

 from pydantic import BaseModel

-from configs import dify_config
-

 class FileInfo(BaseModel):
     filename: str

@@ -89,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response):
         mimetype=mimetype,
         size=int(response.headers.get("Content-Length", -1)),
     )
-
-
-def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
-    return {
-        "opening_statement": features_dict.get("opening_statement"),
-        "suggested_questions": features_dict.get("suggested_questions", []),
-        "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
-        "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
-        "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
-        "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
-        "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
-        "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
-        "user_input_form": user_input_form,
-        "sensitive_word_avoidance": features_dict.get(
-            "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
-        ),
-        "file_upload": features_dict.get(
-            "file_upload",
-            {
-                "image": {
-                    "enabled": False,
-                    "number_limits": DEFAULT_FILE_NUMBER_LIMITS,
-                    "detail": "high",
-                    "transfer_methods": ["remote_url", "local_file"],
-                }
-            },
-        ),
-        "system_parameters": {
-            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
-        },
-    }
@@ -85,5 +85,35 @@ class RuleCodeGenerateApi(Resource):
         return code_result


+class RuleStructuredOutputGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+
+        account = current_user
+        try:
+            structured_output = LLMGenerator.generate_structured_output(
+                tenant_id=account.current_tenant_id,
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+
+        return structured_output
+
+
 api.add_resource(RuleGenerateApi, "/rule-generate")
 api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
+api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
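Note: a rough sketch of calling the new endpoint from a script; the base URL, URL prefix, and bearer token are placeholders, and the model_config shape shown here is an assumption rather than the exact schema the console expects.

    import requests

    BASE_URL = "http://localhost:5001/console/api"  # hypothetical host and prefix
    headers = {"Authorization": "Bearer <console-access-token>"}  # placeholder credential

    payload = {
        "instruction": "Extract the product name and price from the user's message.",
        "model_config": {"provider": "openai", "name": "gpt-4o-mini", "completion_params": {}},  # assumed shape
    }

    resp = requests.post(f"{BASE_URL}/rule-structured-output-generate", json=payload, headers=headers)
    print(resp.status_code, resp.json())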
@@ -1,5 +1,4 @@
-from datetime import datetime
+from dateutil.parser import isoparse

 from flask_restful import Resource, marshal_with, reqparse  # type: ignore
 from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy.orm import Session

@@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource):

         args.status = WorkflowRunStatus(args.status) if args.status else None
         if args.created_at__before:
-            args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
+            args.created_at__before = isoparse(args.created_at__before)

         if args.created_at__after:
-            args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
+            args.created_at__after = isoparse(args.created_at__after)

         # get paginate workflow app logs
         workflow_app_service = WorkflowAppService()
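Note: dateutil.parser.isoparse accepts the trailing "Z" designator directly, which is why the manual .replace("Z", "+00:00") workaround is dropped; datetime.fromisoformat only gained "Z" support in Python 3.11. A small illustration:

    from datetime import datetime
    from dateutil.parser import isoparse

    ts = "2024-05-01T12:30:00Z"

    # Handles the Z suffix directly -> 2024-05-01 12:30:00+00:00
    print(isoparse(ts))

    # The previous workaround, required by datetime.fromisoformat on Python < 3.11
    print(datetime.fromisoformat(ts.replace("Z", "+00:00")))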
@@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
         if not oauth_provider:
             return {"error": "Invalid provider"}, 400
         if "code" in request.args:
-            code = request.args.get("code")
+            code = request.args.get("code", "")
+            if not code:
+                return {"error": "Invalid code"}, 400
             try:
                 oauth_provider.get_access_token(code)
             except requests.exceptions.HTTPError as e:
@@ -16,7 +16,7 @@ from controllers.console.auth.error import (
     PasswordMismatchError,
 )
 from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import email_password_login_enabled, setup_required
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
 from libs.helper import email, extract_remote_ip

@@ -30,6 +30,7 @@ from services.feature_service import FeatureService

 class ForgotPasswordSendEmailApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")

@@ -62,6 +63,7 @@ class ForgotPasswordSendEmailApi(Resource):

 class ForgotPasswordCheckApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=str, required=True, location="json")

@@ -86,12 +88,21 @@ class ForgotPasswordCheckApi(Resource):
             AccountService.add_forgot_password_error_rate_limit(args["email"])
             raise EmailCodeError()

+        # Verified, revoke the first token
+        AccountService.revoke_reset_password_token(args["token"])
+
+        # Refresh token data by generating a new token
+        _, new_token = AccountService.generate_reset_password_token(
+            user_email, code=args["code"], additional_data={"phase": "reset"}
+        )
+
         AccountService.reset_forgot_password_error_rate_limit(args["email"])
-        return {"is_valid": True, "email": token_data.get("email")}
+        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


 class ForgotPasswordResetApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("token", type=str, required=True, nullable=False, location="json")

@@ -107,6 +118,9 @@ class ForgotPasswordResetApi(Resource):
         reset_data = AccountService.get_reset_password_data(args["token"])
         if not reset_data:
             raise InvalidTokenError()
+        # Must use token in reset phase
+        if reset_data.get("phase", "") != "reset":
+            raise InvalidTokenError()
+
         # Revoke token to prevent reuse
         AccountService.revoke_reset_password_token(args["token"])
@@ -22,7 +22,7 @@ from controllers.console.error import (
     EmailSendIpLimitError,
     NotAllowedCreateWorkspace,
 )
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import email_password_login_enabled, setup_required
 from events.tenant_event import tenant_was_created
 from libs.helper import email, extract_remote_ip
 from libs.password import valid_password

@@ -38,6 +38,7 @@ class LoginApi(Resource):
     """Resource for user login."""

     @setup_required
+    @email_password_login_enabled
     def post(self):
         """Authenticate user and login."""
         parser = reqparse.RequestParser()

@@ -110,6 +111,7 @@ class LogoutApi(Resource):

 class ResetPasswordSendEmailApi(Resource):
     @setup_required
+    @email_password_login_enabled
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")
@@ -1,10 +1,10 @@
 from flask_restful import marshal_with  # type: ignore

 from controllers.common import fields
-from controllers.common import helpers as controller_helpers
 from controllers.console import api
 from controllers.console.app.error import AppUnavailableError
 from controllers.console.explore.wraps import InstalledAppResource
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
 from models.model import AppMode, InstalledApp
 from services.app_service import AppService


@@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource):

         user_input_form = features_dict.get("user_input_form", [])

-        return controller_helpers.get_parameters_from_feature_dict(
-            features_dict=features_dict, user_input_form=user_input_form
-        )
+        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)


 class ExploreAppMetaApi(InstalledAppResource):
@@ -210,3 +210,16 @@ def enterprise_license_required(view):
             return view(*args, **kwargs)

     return decorated
+
+
+def email_password_login_enabled(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        features = FeatureService.get_system_features()
+        if features.enable_email_password_login:
+            return view(*args, **kwargs)
+
+        # otherwise, return 403
+        abort(403)
+
+    return decorated
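Note: a minimal sketch of how the new decorator composes with the existing wrappers on a resource method (the ordering mirrors the login and forgot-password hunks above); the resource class and route here are illustrative, not part of this commit.

    from flask_restful import Resource, reqparse  # type: ignore

    from controllers.console.wraps import email_password_login_enabled, setup_required


    class ExamplePasswordLoginApi(Resource):  # hypothetical resource for illustration
        @setup_required
        @email_password_login_enabled  # aborts with 403 when email/password login is disabled
        def post(self):
            parser = reqparse.RequestParser()
            parser.add_argument("email", type=str, required=True, location="json")
            args = parser.parse_args()
            return {"result": "success", "email": args["email"]}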
@@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
 from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
 from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
 from core.plugin.entities.request import (
+    RequestFetchAppInfo,
     RequestInvokeApp,
     RequestInvokeEncrypt,
     RequestInvokeLLM,

@@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource):
         return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()


+class PluginFetchAppInfoApi(Resource):
+    @setup_required
+    @plugin_inner_api_only
+    @get_user_tenant
+    @plugin_data(payload_type=RequestFetchAppInfo)
+    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
+        return BaseBackwardsInvocationResponse(
+            data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
+        ).model_dump()
+
+
 api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
 api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
 api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")

@@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app")
 api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
 api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
 api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
+api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")
@@ -1,10 +1,10 @@
 from flask_restful import Resource, marshal_with  # type: ignore

 from controllers.common import fields
-from controllers.common import helpers as controller_helpers
 from controllers.service_api import api
 from controllers.service_api.app.error import AppUnavailableError
 from controllers.service_api.wraps import validate_app_token
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
 from models.model import App, AppMode
 from services.app_service import AppService


@@ -32,9 +32,7 @@ class AppParameterApi(Resource):

         user_input_form = features_dict.get("user_input_form", [])

-        return controller_helpers.get_parameters_from_feature_dict(
-            features_dict=features_dict, user_input_form=user_input_form
-        )
+        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)


 class AppMetaApi(Resource):
@@ -1,6 +1,6 @@
 import logging
-from datetime import datetime

+from dateutil.parser import isoparse
 from flask_restful import Resource, fields, marshal_with, reqparse  # type: ignore
 from flask_restful.inputs import int_range  # type: ignore
 from sqlalchemy.orm import Session

@@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource):

         args.status = WorkflowRunStatus(args.status) if args.status else None
         if args.created_at__before:
-            args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
+            args.created_at__before = isoparse(args.created_at__before)

         if args.created_at__after:
-            args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
+            args.created_at__after = isoparse(args.created_at__after)

         # get paginate workflow app logs
         workflow_app_service = WorkflowAppService()
@@ -139,7 +139,9 @@ class DatasetListApi(DatasetApiResource):
                 external_knowledge_id=args["external_knowledge_id"],
                 embedding_model_provider=args["embedding_model_provider"],
                 embedding_model_name=args["embedding_model"],
-                retrieval_model=RetrievalModel(**args["retrieval_model"]),
+                retrieval_model=RetrievalModel(**args["retrieval_model"])
+                if args["retrieval_model"] is not None
+                else None,
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -122,6 +122,8 @@ class SegmentApi(DatasetApiResource):
             tenant_id=current_user.current_tenant_id,
             status_list=args["status"],
             keyword=args["keyword"],
+            page=page,
+            limit=limit,
         )

         response = {
@@ -1,10 +1,10 @@
 from flask_restful import marshal_with  # type: ignore

 from controllers.common import fields
-from controllers.common import helpers as controller_helpers
 from controllers.web import api
 from controllers.web.error import AppUnavailableError
 from controllers.web.wraps import WebApiResource
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
 from models.model import App, AppMode
 from services.app_service import AppService


@@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource):

         user_input_form = features_dict.get("user_input_form", [])

-        return controller_helpers.get_parameters_from_feature_dict(
-            features_dict=features_dict, user_input_form=user_input_form
-        )
+        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)


 class AppMeta(WebApiResource):
@@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
         "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
         "created_at": TimestampField,
         "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
+        "metadata": fields.Raw(attribute="message_metadata_dict"),
         "status": fields.String,
         "error": fields.String,
     }
@@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter):
         return cast_parameter_value(self, value)

     type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
+    help: Optional[I18nObject] = None

     def init_frontend_parameter(self, value: Any):
         return init_frontend_parameter(self, self.type, value)
@@ -0,0 +1,45 @@
+from collections.abc import Mapping
+from typing import Any
+
+from configs import dify_config
+from constants import DEFAULT_FILE_NUMBER_LIMITS
+
+
+def get_parameters_from_feature_dict(
+    *, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
+) -> Mapping[str, Any]:
+    """
+    Mapping from feature dict to webapp parameters
+    """
+    return {
+        "opening_statement": features_dict.get("opening_statement"),
+        "suggested_questions": features_dict.get("suggested_questions", []),
+        "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
+        "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
+        "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
+        "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
+        "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
+        "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
+        "user_input_form": user_input_form,
+        "sensitive_word_avoidance": features_dict.get(
+            "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
+        ),
+        "file_upload": features_dict.get(
+            "file_upload",
+            {
+                "image": {
+                    "enabled": False,
+                    "number_limits": DEFAULT_FILE_NUMBER_LIMITS,
+                    "detail": "high",
+                    "transfer_methods": ["remote_url", "local_file"],
+                }
+            },
+        ),
+        "system_parameters": {
+            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+        },
+    }
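Note: the AppParameterApi controllers above now delegate to this new shared helper. A minimal usage sketch with hand-written inputs (the features dict and user_input_form values are illustrative, not real app data):

    from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict

    features_dict = {
        "opening_statement": "Hi, how can I help?",
        "suggested_questions": ["What can you do?"],
        "speech_to_text": {"enabled": True},
    }
    user_input_form = [{"text-input": {"label": "Name", "variable": "name", "required": False}}]  # assumed shape

    params = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
    print(params["speech_to_text"])  # {'enabled': True}
    print(params["file_upload"])     # falls back to the default image-upload config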
@@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline:
                     session=session, workflow_run_id=self._workflow_run_id
                 )
                 workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
-                    session=session, workflow_run=workflow_run, event=event
+                    workflow_run=workflow_run, event=event
                 )
                 node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
-                    session=session,
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,

@@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline:
                     session=session, workflow_run_id=self._workflow_run_id
                 )
                 workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
-                    session=session, workflow_run=workflow_run, event=event
+                    workflow_run=workflow_run, event=event
                )

                 node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
-                    session=session,
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,

@@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline:

            with Session(db.engine, expire_on_commit=False) as session:
                workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
-                    session=session, event=event
+                    event=event
                )

                node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
-                    session=session,
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution,

@@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline:
            | QueueNodeInLoopFailedEvent
            | QueueNodeExceptionEvent,
        ):
-            with Session(db.engine, expire_on_commit=False) as session:
            workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
-                session=session, event=event
+                event=event
            )

            node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
-                session=session,
                event=event,
                task_id=self._application_generate_entity.task_id,
                workflow_node_execution=workflow_node_execution,
            )
-            session.commit()

            if node_finish_resp:
                yield node_finish_resp
@@ -17,6 +17,7 @@ class BaseAppGenerator:
         user_inputs: Optional[Mapping[str, Any]],
         variables: Sequence["VariableEntity"],
         tenant_id: str,
+        strict_type_validation: bool = False,
     ) -> Mapping[str, Any]:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values

@@ -37,6 +38,7 @@ class BaseAppGenerator:
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
                     allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
                 ),
+                strict_type_validation=strict_type_validation,
             )
             for k, v in user_inputs.items()
             if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
@@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 mappings=files,
                 tenant_id=app_model.tenant_id,
                 config=file_extra_config,
+                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
             )

         # convert to app config

@@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
             app_config=app_config,
             file_upload_config=file_extra_config,
             inputs=self._prepare_user_inputs(
-                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+                user_inputs=inputs,
+                variables=app_config.variables,
+                tenant_id=app_model.tenant_id,
+                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
             ),
             files=list(system_files),
             user_id=user.id,
@@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline:
                     session=session, workflow_run_id=self._workflow_run_id
                 )
                 workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
-                    session=session, workflow_run=workflow_run, event=event
+                    workflow_run=workflow_run, event=event
                 )
                 response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
-                    session=session,
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,

@@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline:
                     session=session, workflow_run_id=self._workflow_run_id
                 )
                 workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
-                    session=session, workflow_run=workflow_run, event=event
+                    workflow_run=workflow_run, event=event
                 )
                 node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
-                    session=session,
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,

@@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline:
            if node_start_response:
                yield node_start_response
        elif isinstance(event, QueueNodeSucceededEvent):
-            with Session(db.engine, expire_on_commit=False) as session:
            workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
-                session=session, event=event
+                event=event
            )
            node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
-                session=session,
                event=event,
                task_id=self._application_generate_entity.task_id,
                workflow_node_execution=workflow_node_execution,
            )
-            session.commit()

            if node_success_response:
                yield node_success_response

@@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline:
            | QueueNodeInLoopFailedEvent
            | QueueNodeExceptionEvent,
        ):
-            with Session(db.engine, expire_on_commit=False) as session:
            workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
-                session=session,
                event=event,
            )
            node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
-                session=session,
                event=event,
                task_id=self._application_generate_entity.task_id,
                workflow_node_execution=workflow_node_execution,
            )
-            session.commit()

            if node_failed_response:
                yield node_failed_response

@@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline:
                 workflow_app_log.created_by = self._user_id

                 session.add(workflow_app_log)
+                session.commit()

     def _text_chunk_to_stream_response(
         self, text: str, from_variable_selector: Optional[list[str]] = None
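Note: the WorkflowCycleManage hunks below swap direct SQLAlchemy session calls for a repository created in __init__. The calls made in those hunks imply at least save, update, and get_running_executions; a sketch of that implied surface as a Protocol, assuming the WorkflowNodeExecution model from models.workflow (the real interface under core.repository may define more than this):

    from typing import Protocol, Sequence

    from models.workflow import WorkflowNodeExecution  # model used throughout the hunks below


    class WorkflowNodeExecutionRepositoryLike(Protocol):
        """Minimal surface implied by the calls in WorkflowCycleManage."""

        def save(self, execution: WorkflowNodeExecution) -> None: ...

        def update(self, execution: WorkflowNodeExecution) -> None: ...

        def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: ...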
@@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
 from uuid import uuid4

 from sqlalchemy import func, select
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker

 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (

@@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
+from core.repository import RepositoryFactory
 from core.tools.tool_manager import ToolManager
 from core.workflow.entities.node_entities import NodeRunMetadataKey
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.workflow_entry import WorkflowEntry
+from extensions.ext_database import db
 from models.account import Account
 from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
 from models.model import EndUser

@@ -80,6 +82,21 @@ class WorkflowCycleManage:
         self._application_generate_entity = application_generate_entity
         self._workflow_system_variables = workflow_system_variables

+        # Initialize the session factory and repository
+        # We use the global db engine instead of the session passed to methods
+        # Disable expire_on_commit to avoid the need for merging objects
+        self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+        self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
+            params={
+                "tenant_id": self._application_generate_entity.app_config.tenant_id,
+                "app_id": self._application_generate_entity.app_config.app_id,
+                "session_factory": self._session_factory,
+            }
+        )
+
+        # We'll still keep the cache for backward compatibility and performance
+        # but use the repository for database operations
+
     def _handle_workflow_run_start(
         self,
         *,

@@ -254,19 +271,15 @@ class WorkflowCycleManage:
         workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
         workflow_run.exceptions_count = exceptions_count

-        stmt = select(WorkflowNodeExecution.node_execution_id).where(
-            WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
-            WorkflowNodeExecution.app_id == workflow_run.app_id,
-            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
-            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
-            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
+        # Use the instance repository to find running executions for a workflow run
+        running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
+            workflow_run_id=workflow_run.id
         )
-        ids = session.scalars(stmt).all()
-        # Use self._get_workflow_node_execution here to make sure the cache is updated
-        running_workflow_node_executions = [
-            self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
-        ]
+        # Update the cache with the retrieved executions
+        for execution in running_workflow_node_executions:
+            if execution.node_execution_id:
+                self._workflow_node_executions[execution.node_execution_id] = execution

         for workflow_node_execution in running_workflow_node_executions:
             now = datetime.now(UTC).replace(tzinfo=None)

@@ -288,7 +301,7 @@ class WorkflowCycleManage:
         return workflow_run

     def _handle_node_execution_start(
-        self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
+        self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
     ) -> WorkflowNodeExecution:
         workflow_node_execution = WorkflowNodeExecution()
         workflow_node_execution.id = str(uuid4())

@@ -315,17 +328,14 @@ class WorkflowCycleManage:
         )
         workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)

-        session.add(workflow_node_execution)
+        # Use the instance repository to save the workflow node execution
+        self._workflow_node_execution_repository.save(workflow_node_execution)

         self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
         return workflow_node_execution

-    def _handle_workflow_node_execution_success(
-        self, *, session: Session, event: QueueNodeSucceededEvent
-    ) -> WorkflowNodeExecution:
-        workflow_node_execution = self._get_workflow_node_execution(
-            session=session, node_execution_id=event.node_execution_id
-        )
+    def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
+        workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         outputs = WorkflowEntry.handle_special_values(event.outputs)

@@ -344,13 +354,13 @@ class WorkflowCycleManage:
         workflow_node_execution.finished_at = finished_at
         workflow_node_execution.elapsed_time = elapsed_time

-        workflow_node_execution = session.merge(workflow_node_execution)
+        # Use the instance repository to update the workflow node execution
+        self._workflow_node_execution_repository.update(workflow_node_execution)
         return workflow_node_execution

     def _handle_workflow_node_execution_failed(
         self,
         *,
-        session: Session,
         event: QueueNodeFailedEvent
         | QueueNodeInIterationFailedEvent
         | QueueNodeInLoopFailedEvent

@@ -361,9 +371,7 @@ class WorkflowCycleManage:
         :param event: queue node failed event
         :return:
         """
-        workflow_node_execution = self._get_workflow_node_execution(
-            session=session, node_execution_id=event.node_execution_id
-        )
+        workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)

         inputs = WorkflowEntry.handle_special_values(event.inputs)
         process_data = WorkflowEntry.handle_special_values(event.process_data)

@@ -387,14 +395,14 @@ class WorkflowCycleManage:
         workflow_node_execution.elapsed_time = elapsed_time
         workflow_node_execution.execution_metadata = execution_metadata

-        workflow_node_execution = session.merge(workflow_node_execution)
         return workflow_node_execution

     def _handle_workflow_node_execution_retried(
-        self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+        self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
     ) -> WorkflowNodeExecution:
         """
         Workflow node execution failed
+        :param workflow_run: workflow run
         :param event: queue node failed event
         :return:
         """

@@ -439,15 +447,12 @@ class WorkflowCycleManage:
         workflow_node_execution.execution_metadata = execution_metadata
         workflow_node_execution.index = event.node_run_index

-        session.add(workflow_node_execution)
+        # Use the instance repository to save the workflow node execution
+        self._workflow_node_execution_repository.save(workflow_node_execution)

         self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
         return workflow_node_execution

-    #################################################
-    # to stream responses #
-    #################################################
-
     def _workflow_start_to_stream_response(
         self,
         *,

@@ -455,7 +460,6 @@ class WorkflowCycleManage:
         task_id: str,
         workflow_run: WorkflowRun,
     ) -> WorkflowStartStreamResponse:
-        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
         _ = session
         return WorkflowStartStreamResponse(
             task_id=task_id,

@@ -521,14 +525,10 @@ class WorkflowCycleManage:
     def _workflow_node_start_to_stream_response(
         self,
         *,
-        session: Session,
         event: QueueNodeStartedEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[NodeStartStreamResponse]:
-        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
-        _ = session
-
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
         if not workflow_node_execution.workflow_run_id:

@@ -571,7 +571,6 @@ class WorkflowCycleManage:
     def _workflow_node_finish_to_stream_response(
         self,
         *,
-        session: Session,
         event: QueueNodeSucceededEvent
         | QueueNodeFailedEvent
         | QueueNodeInIterationFailedEvent

@@ -580,8 +579,6 @@ class WorkflowCycleManage:
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[NodeFinishStreamResponse]:
-        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
-        _ = session
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
         if not workflow_node_execution.workflow_run_id:

@@ -621,13 +618,10 @@ class WorkflowCycleManage:
     def _workflow_node_retry_to_stream_response(
         self,
         *,
-        session: Session,
         event: QueueNodeRetryEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
-        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
-        _ = session
         if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
             return None
         if not workflow_node_execution.workflow_run_id:

@@ -668,7 +662,6 @@ class WorkflowCycleManage:
     def _workflow_parallel_branch_start_to_stream_response(
         self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
     ) -> ParallelBranchStartStreamResponse:
-        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
         _ = session
         return ParallelBranchStartStreamResponse(
             task_id=task_id,

@@ -692,7 +685,6 @@ class WorkflowCycleManage:
         workflow_run: WorkflowRun,
         event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
     ) -> ParallelBranchFinishedStreamResponse:
-        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
         _ = session
         return ParallelBranchFinishedStreamResponse(
             task_id=task_id,

@@ -713,7 +705,6 @@ class WorkflowCycleManage:
     def _workflow_iteration_start_to_stream_response(
         self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
     ) -> IterationNodeStartStreamResponse:
-        # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
         _ = session
         return IterationNodeStartStreamResponse(
             task_id=task_id,

@@ -735,7 +726,6 @@ class WorkflowCycleManage:
     def _workflow_iteration_next_to_stream_response(
         self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
def _workflow_iteration_next_to_stream_response(
|
def _workflow_iteration_next_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
|
||||||
) -> IterationNodeNextStreamResponse:
|
) -> IterationNodeNextStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return IterationNodeNextStreamResponse(
|
return IterationNodeNextStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@ -759,7 +749,6 @@ class WorkflowCycleManage:
|
|||||||
def _workflow_iteration_completed_to_stream_response(
|
def _workflow_iteration_completed_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
|
||||||
) -> IterationNodeCompletedStreamResponse:
|
) -> IterationNodeCompletedStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return IterationNodeCompletedStreamResponse(
|
return IterationNodeCompletedStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@ -790,7 +779,6 @@ class WorkflowCycleManage:
|
|||||||
def _workflow_loop_start_to_stream_response(
|
def _workflow_loop_start_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
|
||||||
) -> LoopNodeStartStreamResponse:
|
) -> LoopNodeStartStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return LoopNodeStartStreamResponse(
|
return LoopNodeStartStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@ -812,7 +800,6 @@ class WorkflowCycleManage:
|
|||||||
def _workflow_loop_next_to_stream_response(
|
def _workflow_loop_next_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
|
||||||
) -> LoopNodeNextStreamResponse:
|
) -> LoopNodeNextStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return LoopNodeNextStreamResponse(
|
return LoopNodeNextStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@ -836,7 +823,6 @@ class WorkflowCycleManage:
|
|||||||
def _workflow_loop_completed_to_stream_response(
|
def _workflow_loop_completed_to_stream_response(
|
||||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
|
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
|
||||||
) -> LoopNodeCompletedStreamResponse:
|
) -> LoopNodeCompletedStreamResponse:
|
||||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
|
||||||
_ = session
|
_ = session
|
||||||
return LoopNodeCompletedStreamResponse(
|
return LoopNodeCompletedStreamResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@ -934,11 +920,22 @@ class WorkflowCycleManage:
|
|||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||||
if node_execution_id not in self._workflow_node_executions:
|
# First check the cache for performance
|
||||||
|
if node_execution_id in self._workflow_node_executions:
|
||||||
|
cached_execution = self._workflow_node_executions[node_execution_id]
|
||||||
|
# No need to merge with session since expire_on_commit=False
|
||||||
|
return cached_execution
|
||||||
|
|
||||||
|
# If not in cache, use the instance repository to get by node_execution_id
|
||||||
|
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
|
||||||
|
|
||||||
|
if not execution:
|
||||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
|
||||||
return session.merge(cached_workflow_node_execution)
|
# Update cache
|
||||||
|
self._workflow_node_executions[node_execution_id] = execution
|
||||||
|
return execution
|
||||||
|
|
||||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||||
"""
|
"""
|
||||||
|
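The `_get_workflow_node_execution` rewrite above replaces the session-bound cache with a cache-first lookup that falls back to the injected repository and then refreshes the cache. A minimal standalone sketch of that pattern, assuming a hypothetical `NodeExecutionStore` and a plain dict standing in for the ORM model (none of these names come from the diff):

```python
from typing import Optional, Protocol


class _Repository(Protocol):
    # Stand-in for the repository interface; only the method used here.
    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[dict]: ...


class NodeExecutionStore:
    """Hypothetical cache-then-repository lookup mirroring the pattern above."""

    def __init__(self, repository: _Repository):
        self._repository = repository
        self._cache: dict[str, dict] = {}

    def get(self, node_execution_id: str) -> dict:
        # Serve from the in-process cache when possible (no session.merge needed
        # when the ORM session is configured with expire_on_commit=False).
        if node_execution_id in self._cache:
            return self._cache[node_execution_id]

        execution = self._repository.get_by_node_execution_id(node_execution_id)
        if not execution:
            raise ValueError(f"Workflow node execution not found: {node_execution_id}")

        self._cache[node_execution_id] = execution  # refresh the cache for later lookups
        return execution
```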
@@ -6,7 +6,6 @@ from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
-from models.model import DatasetRetrieverResource


 class DatasetIndexToolCallbackHandler:
@@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:

     def return_retriever_resource_info(self, resource: list):
         """Handle return_retriever_resource_info."""
-        if resource and len(resource) > 0:
-            for item in resource:
-                dataset_retriever_resource = DatasetRetrieverResource(
-                    message_id=self._message_id,
-                    position=item.get("position") or 0,
-                    dataset_id=item.get("dataset_id"),
-                    dataset_name=item.get("dataset_name"),
-                    document_id=item.get("document_id"),
-                    document_name=item.get("document_name"),
-                    data_source_type=item.get("data_source_type"),
-                    segment_id=item.get("segment_id"),
-                    score=item.get("score") if "score" in item else None,
-                    hit_count=item.get("hit_count") if "hit_count" in item else None,
-                    word_count=item.get("word_count") if "word_count" in item else None,
-                    segment_position=item.get("segment_position") if "segment_position" in item else None,
-                    index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
-                    content=item.get("content"),
-                    retriever_from=item.get("retriever_from"),
-                    created_by=self._user_id,
-                )
-                db.session.add(dataset_retriever_resource)
-                db.session.commit()
-
         self._queue_manager.publish(
             QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
         )
@@ -48,25 +48,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
         write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
     )

+    if "ssl_verify" not in kwargs:
+        kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
+
+    ssl_verify = kwargs.pop("ssl_verify")
+
     retries = 0
     while retries <= max_retries:
         try:
             if dify_config.SSRF_PROXY_ALL_URL:
-                with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
+                with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
                     response = client.request(method=method, url=url, **kwargs)
             elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
                 proxy_mounts = {
-                    "http://": httpx.HTTPTransport(
-                        proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
-                    ),
-                    "https://": httpx.HTTPTransport(
-                        proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
-                    ),
+                    "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
+                    "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
                 }
-                with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
+                with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
                     response = client.request(method=method, url=url, **kwargs)
             else:
-                with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
+                with httpx.Client(verify=ssl_verify) as client:
                     response = client.request(method=method, url=url, **kwargs)

             if response.status_code not in STATUS_FORCELIST:
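The hunk above threads a per-call `ssl_verify` option through to every `httpx.Client`, defaulting to `HTTP_REQUEST_NODE_SSL_VERIFY` when the caller does not set it. A hedged usage sketch; the module path and the URL are assumptions, not taken from this diff:

```python
from core.helper import ssrf_proxy  # assumed import path for the helper changed above

# Passing ssl_verify explicitly is shown only to illustrate the new override;
# omitting it keeps the configured HTTP_REQUEST_NODE_SSL_VERIFY default.
response = ssrf_proxy.make_request(
    "GET",
    "https://internal.example.com/healthz",  # illustrative URL
    ssl_verify=False,                        # forwarded to httpx.Client(verify=...)
)
print(response.status_code)
```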
@@ -10,6 +10,7 @@ from core.llm_generator.prompts import (
     GENERATOR_QA_PROMPT,
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
+    SYSTEM_STRUCTURED_OUTPUT_GENERATE,
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager
@@ -340,3 +341,37 @@ class LLMGenerator:

         answer = cast(str, response.message.content)
         return answer.strip()
+
+    @classmethod
+    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
+        )
+
+        prompt_messages = [
+            SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
+            UserPromptMessage(content=instruction),
+        ]
+        model_parameters = model_config.get("model_parameters", {})
+
+        try:
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+                ),
+            )
+
+            generated_json_schema = cast(str, response.message.content)
+            return {"output": generated_json_schema, "error": ""}
+
+        except InvokeError as e:
+            error = str(e)
+            return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
+        except Exception as e:
+            logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
+            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
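A hedged sketch of calling the new helper. The tenant ID, provider, and model name are placeholders, and the module path is an assumption; the `{"output": ..., "error": ...}` return shape is the one defined in the hunk above:

```python
from core.llm_generator.llm_generator import LLMGenerator  # assumed module path

result = LLMGenerator.generate_structured_output(
    tenant_id="tenant-123",            # placeholder
    instruction="I need name and age",
    model_config={
        "provider": "openai",          # placeholder provider/model
        "name": "gpt-4o-mini",
        "model_parameters": {"temperature": 0},
    },
)
if result["error"]:
    print("generation failed:", result["error"])
else:
    print("generated JSON Schema:", result["output"])
```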
@@ -220,3 +220,110 @@ Here is the task description: {{INPUT_TEXT}}

 You just need to generate the output
 """  # noqa: E501
+
+SYSTEM_STRUCTURED_OUTPUT_GENERATE = """
+Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
+
+## Instructions:
+
+1. Analyze the user's description of their data needs
+2. Identify each property that should be included in the schema
+3. Determine the appropriate data type for each property
+4. Decide which properties should be required
+5. Generate a complete JSON Schema with proper syntax
+6. Include appropriate constraints when specified (min/max values, patterns, formats)
+7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
+8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly.
+
+## Examples:
+
+### Example 1:
+**User Input:** I need name and age
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "name": { "type": "string" },
+    "age": { "type": "number" }
+  },
+  "required": ["name", "age"]
+}
+
+### Example 2:
+**User Input:** I want to store information about books including title, author, publication year and optional page count
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "title": { "type": "string" },
+    "author": { "type": "string" },
+    "publicationYear": { "type": "integer" },
+    "pageCount": { "type": "integer" }
+  },
+  "required": ["title", "author", "publicationYear"]
+}
+
+### Example 3:
+**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "email": {
+      "type": "string",
+      "format": "email"
+    },
+    "password": {
+      "type": "string",
+      "minLength": 8
+    },
+    "age": {
+      "type": "integer",
+      "minimum": 18
+    }
+  },
+  "required": ["email", "password", "age"]
+}
+
+### Example 4:
+**User Input:** I need album schema, the ablum has songs, and each song has name, duration, and artist.
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "properties": {
+      "songs": {
+        "type": "array",
+        "items": {
+          "type": "object",
+          "properties": {
+            "name": {
+              "type": "string"
+            },
+            "id": {
+              "type": "string"
+            },
+            "duration": {
+              "type": "string"
+            },
+            "aritst": {
+              "type": "string"
+            }
+          },
+          "required": [
+            "name",
+            "id",
+            "duration",
+            "aritst"
+          ]
+        }
+      }
+    }
+  },
+  "required": [
+    "songs"
+  ]
+}
+
+Now, generate a JSON Schema based on my description
+"""  # noqa: E501
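Because the prompt instructs the model to return a raw JSON Schema with no fences, a caller may want to sanity-check the output before relying on it. A minimal sketch using the third-party `jsonschema` package, which is an assumption and is not referenced anywhere in this diff:

```python
import json

from jsonschema import Draft202012Validator  # assumed dependency, not part of this diff

# Example model output (hardcoded here for illustration).
raw = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'

schema = json.loads(raw)                   # fails fast if the model wrapped the JSON in prose or fences
Draft202012Validator.check_schema(schema)  # raises SchemaError if it is not a valid JSON Schema
print("schema looks valid, required fields:", schema["required"])
```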
@@ -44,6 +44,7 @@ class TokenBufferMemory:
                 Message.created_at,
                 Message.workflow_run_id,
                 Message.parent_message_id,
+                Message.answer_tokens,
             )
             .filter(
                 Message.conversation_id == self.conversation.id,
@@ -63,7 +64,7 @@ class TokenBufferMemory:
         thread_messages = extract_thread_messages(messages)

         # for newly created message, its answer is temporarily empty, we don't need to add it to memory
-        if thread_messages and not thread_messages[0].answer:
+        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
             thread_messages.pop(0)

         messages = list(reversed(thread_messages))
@@ -2,7 +2,7 @@ from decimal import Decimal
 from enum import Enum, StrEnum
 from typing import Any, Optional

-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator

 from core.model_runtime.entities.common_entities import I18nObject

@@ -85,6 +85,7 @@ class ModelFeature(Enum):
     DOCUMENT = "document"
     VIDEO = "video"
     AUDIO = "audio"
+    STRUCTURED_OUTPUT = "structured-output"


 class DefaultParameterName(StrEnum):
@@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel):
     parameter_rules: list[ParameterRule] = []
     pricing: Optional[PriceConfig] = None

+    @model_validator(mode="after")
+    def validate_model(self):
+        supported_schema_keys = ["json_schema"]
+        schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
+        if not schema_key:
+            return self
+        if self.features is None:
+            self.features = [ModelFeature.STRUCTURED_OUTPUT]
+        else:
+            if ModelFeature.STRUCTURED_OUTPUT not in self.features:
+                self.features.append(ModelFeature.STRUCTURED_OUTPUT)
+        return self
+

 class ModelUsage(BaseModel):
     pass
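The `validate_model` hook above tags any model that exposes a `json_schema` parameter rule with the new STRUCTURED_OUTPUT feature. The same `model_validator(mode="after")` pattern in isolation, using simplified stand-in types rather than dify's entities:

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel, model_validator


class Feature(str, Enum):
    STRUCTURED_OUTPUT = "structured-output"


class Rule(BaseModel):
    name: str


class ModelEntity(BaseModel):
    # Simplified stand-ins for AIModelEntity / ParameterRule / ModelFeature.
    features: Optional[list[Feature]] = None
    parameter_rules: list[Rule] = []

    @model_validator(mode="after")
    def infer_structured_output(self):
        # If a json_schema parameter rule exists, advertise structured-output support.
        if any(rule.name == "json_schema" for rule in self.parameter_rules):
            if self.features is None:
                self.features = [Feature.STRUCTURED_OUTPUT]
            elif Feature.STRUCTURED_OUTPUT not in self.features:
                self.features.append(Feature.STRUCTURED_OUTPUT)
        return self


entity = ModelEntity(parameter_rules=[Rule(name="json_schema")])
assert entity.features == [Feature.STRUCTURED_OUTPUT]
```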
@@ -1,5 +1,6 @@
 import logging
 import time
+import uuid
 from collections.abc import Generator, Sequence
 from typing import Optional, Union

@@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
 logger = logging.getLogger(__name__)


+def _gen_tool_call_id() -> str:
+    return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+
+
+def _increase_tool_call(
+    new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
+):
+    """
+    Merge incremental tool call updates into existing tool calls.
+
+    :param new_tool_calls: List of new tool call deltas to be merged.
+    :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
+    """
+
+    def get_tool_call(tool_call_id: str):
+        """
+        Get or create a tool call by ID
+
+        :param tool_call_id: tool call ID
+        :return: existing or new tool call
+        """
+        if not tool_call_id:
+            return existing_tools_calls[-1]
+
+        _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
+        if _tool_call is None:
+            _tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_call_id,
+                type="function",
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+            )
+            existing_tools_calls.append(_tool_call)
+
+        return _tool_call
+
+    for new_tool_call in new_tool_calls:
+        # generate ID for tool calls with function name but no ID to track them
+        if new_tool_call.function.name and not new_tool_call.id:
+            new_tool_call.id = _gen_tool_call_id()
+        # get tool call
+        tool_call = get_tool_call(new_tool_call.id)
+        # update tool call
+        if new_tool_call.id:
+            tool_call.id = new_tool_call.id
+        if new_tool_call.type:
+            tool_call.type = new_tool_call.type
+        if new_tool_call.function.name:
+            tool_call.function.name = new_tool_call.function.name
+        if new_tool_call.function.arguments:
+            tool_call.function.arguments += new_tool_call.function.arguments
+
+
 class LargeLanguageModel(AIModel):
     """
     Model class for large language model.
@@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
         system_fingerprint = None
         tools_calls: list[AssistantPromptMessage.ToolCall] = []

-        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
-            def get_tool_call(tool_name: str):
-                if not tool_name:
-                    return tools_calls[-1]
-
-                tool_call = next(
-                    (tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
-                )
-                if tool_call is None:
-                    tool_call = AssistantPromptMessage.ToolCall(
-                        id="",
-                        type="",
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
-                    )
-                    tools_calls.append(tool_call)
-
-                return tool_call
-
-            for new_tool_call in new_tool_calls:
-                # get tool call
-                tool_call = get_tool_call(new_tool_call.function.name)
-                # update tool call
-                if new_tool_call.id:
-                    tool_call.id = new_tool_call.id
-                if new_tool_call.type:
-                    tool_call.type = new_tool_call.type
-                if new_tool_call.function.name:
-                    tool_call.function.name = new_tool_call.function.name
-                if new_tool_call.function.arguments:
-                    tool_call.function.arguments += new_tool_call.function.arguments
-
         for chunk in result:
             if isinstance(chunk.delta.message.content, str):
                 content += chunk.delta.message.content
             elif isinstance(chunk.delta.message.content, list):
                 content_list.extend(chunk.delta.message.content)
             if chunk.delta.message.tool_calls:
-                increase_tool_call(chunk.delta.message.tool_calls)
+                _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)

             usage = chunk.delta.usage or LLMUsage.empty_usage()
             system_fingerprint = chunk.system_fingerprint
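The module-level `_increase_tool_call` above stitches streamed tool-call deltas back into whole calls, keyed by tool-call ID rather than function name, and mints an ID when a chunk carries a name but no ID. A standalone sketch of the same merging idea with plain dataclasses; the `ToolCallDelta`/`ToolCall` names are illustrative, not dify's types:

```python
import uuid
from dataclasses import dataclass


@dataclass
class ToolCallDelta:
    id: str = ""
    name: str = ""
    arguments: str = ""


@dataclass
class ToolCall:
    id: str
    name: str = ""
    arguments: str = ""


def merge_deltas(deltas: list[ToolCallDelta], calls: list[ToolCall]) -> None:
    """Merge streamed tool-call deltas into `calls` in place, keyed by tool-call ID."""
    for delta in deltas:
        if delta.name and not delta.id:
            # A call announced with a name but no ID gets a minted ID so later
            # argument fragments (which arrive with an empty ID) can attach to it.
            delta.id = f"chatcmpl-tool-{uuid.uuid4().hex}"
        if delta.id:
            target = next((c for c in calls if c.id == delta.id), None)
        else:
            target = calls[-1] if calls else None
        if target is None:
            target = ToolCall(id=delta.id)
            calls.append(target)
        if delta.name:
            target.name = delta.name
        if delta.arguments:
            target.arguments += delta.arguments


calls: list[ToolCall] = []
merge_deltas(
    [ToolCallDelta(name="get_weather"), ToolCallDelta(arguments='{"city": '), ToolCallDelta(arguments='"Paris"}')],
    calls,
)
assert calls[0].name == "get_weather"
assert calls[0].arguments == '{"city": "Paris"}'
```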
@@ -5,6 +5,7 @@ from datetime import datetime, timedelta
 from typing import Optional

 from langfuse import Langfuse  # type: ignore
+from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import LangfuseConfig
@@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
     UnitEnum,
 )
 from core.ops.utils import filter_none_values
+from core.repository.repository_factory import RepositoryFactory
 from extensions.ext_database import db
 from models.model import EndUser
-from models.workflow import WorkflowNodeExecution

 logger = logging.getLogger(__name__)

@@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance):
             )
             self.add_trace(langfuse_trace_data=trace_data)

-        # through workflow_run_id get all_nodes_execution
-        workflow_nodes_execution_id_records = (
-            db.session.query(WorkflowNodeExecution.id)
-            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
-            .all()
+        # through workflow_run_id get all_nodes_execution using repository
+        session_factory = sessionmaker(bind=db.engine)
+        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
+            params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
         )

-        for node_execution_id_record in workflow_nodes_execution_id_records:
-            node_execution = (
-                db.session.query(
-                    WorkflowNodeExecution.id,
-                    WorkflowNodeExecution.tenant_id,
-                    WorkflowNodeExecution.app_id,
-                    WorkflowNodeExecution.title,
-                    WorkflowNodeExecution.node_type,
-                    WorkflowNodeExecution.status,
-                    WorkflowNodeExecution.inputs,
-                    WorkflowNodeExecution.outputs,
-                    WorkflowNodeExecution.created_at,
-                    WorkflowNodeExecution.elapsed_time,
-                    WorkflowNodeExecution.process_data,
-                    WorkflowNodeExecution.execution_metadata,
-                )
-                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
-                .first()
-            )
-
-            if not node_execution:
-                continue
-
+        # Get all executions for this workflow run
+        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+            workflow_run_id=trace_info.workflow_run_id
+        )
+
+        for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
             tenant_id = node_execution.tenant_id
             app_id = node_execution.app_id
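All three tracing integrations touched by this commit (Langfuse above, LangSmith and Opik below) switch from ad-hoc `db.session.query(WorkflowNodeExecution...)` calls to a repository created via `RepositoryFactory`. The shared shape of that change, sketched with placeholder IDs; the factory and params layout is the one shown in the hunks themselves:

```python
from sqlalchemy.orm import sessionmaker

from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db

session_factory = sessionmaker(bind=db.engine)
repository = RepositoryFactory.create_workflow_node_execution_repository(
    params={
        "tenant_id": "tenant-123",       # placeholder
        "app_id": "app-456",             # optional; the Langfuse variant above omits it
        "session_factory": session_factory,
    },
)

# Fetch every node execution recorded for one workflow run.
for node_execution in repository.get_by_workflow_run(workflow_run_id="run-789"):
    print(node_execution.id, node_execution.node_type, node_execution.elapsed_time)
```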
@@ -7,6 +7,7 @@ from typing import Optional, cast

 from langsmith import Client
 from langsmith.schemas import RunBase
+from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import LangSmithConfig
@@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
     LangSmithRunUpdateModel,
 )
 from core.ops.utils import filter_none_values, generate_dotted_order
+from core.repository.repository_factory import RepositoryFactory
 from extensions.ext_database import db
 from models.model import EndUser, MessageFile
-from models.workflow import WorkflowNodeExecution

 logger = logging.getLogger(__name__)

@@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance):

         self.add_run(langsmith_run)

-        # through workflow_run_id get all_nodes_execution
-        workflow_nodes_execution_id_records = (
-            db.session.query(WorkflowNodeExecution.id)
-            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
-            .all()
+        # through workflow_run_id get all_nodes_execution using repository
+        session_factory = sessionmaker(bind=db.engine)
+        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
+            params={
+                "tenant_id": trace_info.tenant_id,
+                "app_id": trace_info.metadata.get("app_id"),
+                "session_factory": session_factory,
+            },
         )

-        for node_execution_id_record in workflow_nodes_execution_id_records:
-            node_execution = (
-                db.session.query(
-                    WorkflowNodeExecution.id,
-                    WorkflowNodeExecution.tenant_id,
-                    WorkflowNodeExecution.app_id,
-                    WorkflowNodeExecution.title,
-                    WorkflowNodeExecution.node_type,
-                    WorkflowNodeExecution.status,
-                    WorkflowNodeExecution.inputs,
-                    WorkflowNodeExecution.outputs,
-                    WorkflowNodeExecution.created_at,
-                    WorkflowNodeExecution.elapsed_time,
-                    WorkflowNodeExecution.process_data,
-                    WorkflowNodeExecution.execution_metadata,
-                )
-                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
-                .first()
-            )
-
-            if not node_execution:
-                continue
-
+        # Get all executions for this workflow run
+        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+            workflow_run_id=trace_info.workflow_run_id
+        )
+
+        for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
             tenant_id = node_execution.tenant_id
             app_id = node_execution.app_id
@@ -7,6 +7,7 @@ from typing import Optional, cast

 from opik import Opik, Trace
 from opik.id_helpers import uuid4_to_uuid7
+from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import OpikConfig
@@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import (
     TraceTaskName,
     WorkflowTraceInfo,
 )
+from core.repository.repository_factory import RepositoryFactory
 from extensions.ext_database import db
 from models.model import EndUser, MessageFile
-from models.workflow import WorkflowNodeExecution

 logger = logging.getLogger(__name__)

@@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance):
         }
         self.add_trace(trace_data)

-        # through workflow_run_id get all_nodes_execution
-        workflow_nodes_execution_id_records = (
-            db.session.query(WorkflowNodeExecution.id)
-            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
-            .all()
+        # through workflow_run_id get all_nodes_execution using repository
+        session_factory = sessionmaker(bind=db.engine)
+        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
+            params={
+                "tenant_id": trace_info.tenant_id,
+                "app_id": trace_info.metadata.get("app_id"),
+                "session_factory": session_factory,
+            },
         )

-        for node_execution_id_record in workflow_nodes_execution_id_records:
-            node_execution = (
-                db.session.query(
-                    WorkflowNodeExecution.id,
-                    WorkflowNodeExecution.tenant_id,
-                    WorkflowNodeExecution.app_id,
-                    WorkflowNodeExecution.title,
-                    WorkflowNodeExecution.node_type,
-                    WorkflowNodeExecution.status,
-                    WorkflowNodeExecution.inputs,
-                    WorkflowNodeExecution.outputs,
-                    WorkflowNodeExecution.created_at,
-                    WorkflowNodeExecution.elapsed_time,
-                    WorkflowNodeExecution.process_data,
-                    WorkflowNodeExecution.execution_metadata,
-                )
-                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
-                .first()
-            )
-
-            if not node_execution:
-                continue
-
+        # Get all executions for this workflow run
+        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+            workflow_run_id=trace_info.workflow_run_id
+        )
+
+        for node_execution in workflow_node_executions:
             node_execution_id = node_execution.id
             tenant_id = node_execution.tenant_id
             app_id = node_execution.app_id
@@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
 from typing import Optional, Union

 from controllers.service_api.wraps import create_or_update_end_user_for_user_id
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
 from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
 from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
 from core.app.apps.chat.app_generator import ChatAppGenerator
@@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser


 class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
+    @classmethod
+    def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
+        """
+        Fetch app info
+        """
+        app = cls._get_app(app_id, tenant_id)
+
+        """Retrieve app parameters."""
+        if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+            workflow = app.workflow
+            if workflow is None:
+                raise ValueError("unexpected app type")
+
+            features_dict = workflow.features_dict
+            user_input_form = workflow.user_input_form(to_old_structure=True)
+        else:
+            app_model_config = app.app_model_config
+            if app_model_config is None:
+                raise ValueError("unexpected app type")
+
+            features_dict = app_model_config.to_dict()
+
+            user_input_form = features_dict.get("user_input_form", [])
+
+        return {
+            "data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
+        }
+
     @classmethod
     def invoke_app(
         cls,
@@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         :param query: str
         :return: dict
         """
+        # FIXME(-LAN-): Avoid import service into core
         workflow_service = WorkflowService()
         node_id = "1919810"
         node_data = ParameterExtractorNodeData(
@@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         :param query: str
         :return: dict
         """
+        # FIXME(-LAN-): Avoid import service into core
         workflow_service = WorkflowService()
         node_id = "1919810"
         node_data = QuestionClassifierNodeData(
@@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel):

     filename: str
     mimetype: str
+
+
+class RequestFetchAppInfo(BaseModel):
+    """
+    Request to fetch app info
+    """
+
+    app_id: str
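A small sketch of how the new request model might be parsed on the plugin-invocation side. The payload is illustrative and the endpoint wiring is not shown in this diff; only the `app_id` field comes from the hunk above:

```python
from pydantic import BaseModel


class RequestFetchAppInfo(BaseModel):
    """Mirror of the request model added above, kept local so the sketch is self-contained."""

    app_id: str


payload = {"app_id": "app-456"}  # illustrative payload from a plugin
request = RequestFetchAppInfo.model_validate(payload)
# A handler would presumably pass request.app_id (plus the tenant) to fetch_app_info.
print(request.app_id)
```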
@@ -124,6 +124,15 @@ class ProviderManager:

         # Get All preferred provider types of the workspace
         provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
+        # Ensure that both the original provider name and its ModelProviderID string representation
+        # are present in the dictionary to handle cases where either form might be used
+        for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
+            provider_id = ModelProviderID(provider_name)
+            if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
+                # Add the ModelProviderID string representation if it's not already present
+                provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
+                    provider_name_to_preferred_model_provider_records_dict[provider_name]
+                )

         # Get All provider model settings
         provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@@ -497,8 +506,8 @@ class ProviderManager:

     @staticmethod
     def _init_trial_provider_records(
-        tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]
-    ) -> dict[str, list]:
+        tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
+    ) -> dict[str, list[Provider]]:
         """
         Initialize trial provider records if not exists.

@@ -532,7 +541,7 @@ class ProviderManager:
             if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
                 try:
                     # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
-                    provider_record = Provider(
+                    new_provider_record = Provider(
                         tenant_id=tenant_id,
                         # TODO: Use provider name with prefix after the data migration.
                         provider_name=ModelProviderID(provider_name).provider_name,
@@ -542,11 +551,12 @@ class ProviderManager:
                         quota_used=0,
                         is_valid=True,
                     )
-                    db.session.add(provider_record)
+                    db.session.add(new_provider_record)
                     db.session.commit()
+                    provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
                 except IntegrityError:
                     db.session.rollback()
-                    provider_record = (
+                    existed_provider_record = (
                         db.session.query(Provider)
                         .filter(
                             Provider.tenant_id == tenant_id,
@@ -556,11 +566,14 @@ class ProviderManager:
                         )
                         .first()
                     )
-                    if provider_record and not provider_record.is_valid:
-                        provider_record.is_valid = True
+                    if not existed_provider_record:
+                        continue
+
+                    if not existed_provider_record.is_valid:
+                        existed_provider_record.is_valid = True
                         db.session.commit()

-                    provider_name_to_provider_records_dict[provider_name].append(provider_record)
+                    provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)

         return provider_name_to_provider_records_dict

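The first ProviderManager hunk above makes the preferred-provider lookup tolerant of both the bare provider name and its `ModelProviderID` string form. The aliasing step in isolation, with a toy `to_prefixed` helper standing in for `str(ModelProviderID(...))` (the real class and its exact prefix format are not part of this sketch):

```python
def to_prefixed(provider_name: str) -> str:
    # Toy stand-in for str(ModelProviderID(name)); the real class lives in dify's
    # provider entities and may use a different prefix scheme.
    return provider_name if "/" in provider_name else f"vendor/{provider_name}/{provider_name}"


preferred = {"openai": {"preferred_provider_type": "system"}}

# Ensure both key forms resolve to the same record.
for name in list(preferred.keys()):
    prefixed = to_prefixed(name)
    if prefixed not in preferred:
        preferred[prefixed] = preferred[name]

assert preferred["vendor/openai/openai"] is preferred["openai"]
```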
@@ -246,7 +246,7 @@ class AnalyticdbVectorBySql:
                 ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                 FROM {self.table_name}
                 WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
-                ORDER BY (score,id) DESC
+                ORDER BY score DESC, id DESC
                 LIMIT {top_k}""",
                 (f"'{query}'", f"'{query}'"),
             )
@@ -32,6 +32,7 @@ class MilvusConfig(BaseModel):
     batch_size: int = 100  # Batch size for operations
     database: str = "default"  # Database name
     enable_hybrid_search: bool = False  # Flag to enable hybrid search
+    analyzer_params: Optional[str] = None  # Analyzer params

     @model_validator(mode="before")
     @classmethod
@@ -58,6 +59,7 @@ class MilvusConfig(BaseModel):
             "user": self.user,
             "password": self.password,
             "db_name": self.database,
+            "analyzer_params": self.analyzer_params,
         }


@@ -300,14 +302,19 @@ class MilvusVector(BaseVector):

         # Create the text field, enable_analyzer will be set True to support milvus automatically
         # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
-        fields.append(
-            FieldSchema(
-                Field.CONTENT_KEY.value,
-                DataType.VARCHAR,
-                max_length=65_535,
-                enable_analyzer=self._hybrid_search_enabled,
-            )
-        )
+        content_field_kwargs: dict[str, Any] = {
+            "max_length": 65_535,
+            "enable_analyzer": self._hybrid_search_enabled,
+        }
+        if (
+            self._hybrid_search_enabled
+            and self._client_config.analyzer_params is not None
+            and self._client_config.analyzer_params.strip()
+        ):
+            content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
+
+        fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
+
         # Create the primary key field
         fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
         # Create the vector field, supports binary or float vectors
@@ -383,5 +390,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
                 password=dify_config.MILVUS_PASSWORD or "",
                 database=dify_config.MILVUS_DATABASE or "",
                 enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
+                analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "",
             ),
         )
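With the new `analyzer_params` field (exposed as `MILVUS_ANALYZER_PARAMS`), a language-specific analyzer can be forwarded to the VARCHAR content field when hybrid search is enabled. A hedged configuration sketch; the module path, the connection field names, and the analyzer JSON are assumptions rather than content of this diff:

```python
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig  # assumed module path

config = MilvusConfig(
    uri="http://127.0.0.1:19530",   # assumed field names mirroring the MILVUS_* settings
    user="root",
    password="Milvus",
    database="default",
    enable_hybrid_search=True,
    analyzer_params='{"type": "chinese"}',  # example analyzer JSON, not taken from this diff
)
print(config.analyzer_params)
```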
@@ -126,9 +126,7 @@ class WordExtractor(BaseExtractor):

                 db.session.add(upload_file)
                 db.session.commit()
-                image_map[rel.target_part] = (
-                    f""
-                )
+                image_map[rel.target_part] = f""

         return image_map

@@ -39,6 +39,12 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
             else:
                 return [GPT2Tokenizer.get_num_tokens(text) for text in texts]

+        def _character_encoder(texts: list[str]) -> list[int]:
+            if not texts:
+                return []
+
+            return [len(text) for text in texts]
+
         if issubclass(cls, TokenTextSplitter):
             extra_kwargs = {
                 "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
@@ -47,7 +53,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
             }
             kwargs = {**kwargs, **extra_kwargs}

-        return cls(length_function=_token_encoder, **kwargs)
+        return cls(length_function=_character_encoder, **kwargs)


 class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
@@ -103,7 +109,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
         _good_splits_lengths = []  # cache the lengths of the splits
         _separator = "" if self._keep_separator else separator
         s_lens = self._length_function(splits)
-        if _separator != "":
+        if separator != "":
             for s, s_len in zip(splits, s_lens):
                 if s_len < self._chunk_size:
                     _good_splits.append(s)
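The splitter change above swaps the GPT-2 token counter for a plain character count when sizing chunks. A small sketch of what a character-based length function decides for two candidate splits; the chunk size and sample strings are made up for illustration:

```python
def character_lengths(texts: list[str]) -> list[int]:
    # Same behaviour as the _character_encoder added above: length in characters.
    if not texts:
        return []
    return [len(text) for text in texts]


splits = ["A short sentence.", "A considerably longer sentence that should exceed the limit."]
chunk_size = 40  # illustrative threshold

for split, length in zip(splits, character_lengths(splits)):
    keep = length < chunk_size
    print(f"{length:3d} chars -> {'keep as-is' if keep else 'split further'}: {split!r}")
```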
15 api/core/repository/__init__.py Normal file
@@ -0,0 +1,15 @@
+"""
+Repository interfaces for data access.
+
+This package contains repository interfaces that define the contract
+for accessing and manipulating data, regardless of the underlying
+storage mechanism.
+"""
+
+from core.repository.repository_factory import RepositoryFactory
+from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+
+__all__ = [
+    "RepositoryFactory",
+    "WorkflowNodeExecutionRepository",
+]
97
api/core/repository/repository_factory.py
Normal file
97
api/core/repository/repository_factory.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
Repository factory for creating repository instances.
|
||||||
|
|
||||||
|
This module provides a simple factory interface for creating repository instances.
|
||||||
|
It does not contain any implementation details or dependencies on specific repositories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Mapping
|
||||||
|
from typing import Any, Literal, Optional, cast
|
||||||
|
|
||||||
|
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
# Type for factory functions - takes a dict of parameters and returns any repository type
|
||||||
|
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
|
||||||
|
|
||||||
|
# Type for workflow node execution factory function
|
||||||
|
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
|
||||||
|
|
||||||
|
# Repository type literals
|
||||||
|
_RepositoryType = Literal["workflow_node_execution"]
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryFactory:
|
||||||
|
"""
|
||||||
|
Factory class for creating repository instances.
|
||||||
|
|
||||||
|
This factory delegates the actual repository creation to implementation-specific
|
||||||
|
factory functions that are registered with the factory at runtime.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Dictionary to store factory functions
|
||||||
|
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
|
||||||
|
"""
|
||||||
|
Register a factory function for a specific repository type.
|
||||||
|
This is a private method and should not be called directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository_type: The type of repository (e.g., 'workflow_node_execution')
|
||||||
|
factory_func: A function that takes parameters and returns a repository instance
|
||||||
|
"""
|
||||||
|
cls._factory_functions[repository_type] = factory_func
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
|
||||||
|
"""
|
||||||
|
Create a new repository instance with the provided parameters.
|
||||||
|
This is a private method and should not be called directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository_type: The type of repository to create
|
||||||
|
params: A dictionary of parameters to pass to the factory function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of the requested repository
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no factory function is registered for the repository type
|
||||||
|
"""
|
||||||
|
if repository_type not in cls._factory_functions:
|
||||||
|
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
|
||||||
|
|
||||||
|
# Use empty dict if params is None
|
||||||
|
params = params or {}
|
||||||
|
|
||||||
|
return cls._factory_functions[repository_type](params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
|
||||||
|
"""
|
||||||
|
Register a factory function for the workflow node execution repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
|
||||||
|
"""
|
||||||
|
cls._register_factory("workflow_node_execution", factory_func)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_workflow_node_execution_repository(
|
||||||
|
cls, params: Optional[Mapping[str, Any]] = None
|
||||||
|
) -> WorkflowNodeExecutionRepository:
|
||||||
|
"""
|
||||||
|
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: A dictionary of parameters to pass to the factory function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of the WorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no factory function is registered for the workflow_node_execution repository type
|
||||||
|
"""
|
||||||
|
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
|
||||||
|
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
|
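A minimal usage sketch of the factory above (editor-added, not part of the diff). The concrete factory function mirrors the one registered later in api/repositories/repository_registry.py; the "tenant-1" value is a placeholder.

from sqlalchemy.orm import sessionmaker

from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository


def _build_workflow_node_execution_repo(params):
    # tenant_id is required; session_factory and app_id are optional.
    return SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=params.get("session_factory") or sessionmaker(bind=db.engine),
        tenant_id=params["tenant_id"],
        app_id=params.get("app_id"),
    )


RepositoryFactory.register_workflow_node_execution_factory(_build_workflow_node_execution_repo)
repo = RepositoryFactory.create_workflow_node_execution_repository({"tenant_id": "tenant-1"})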
97
api/core/repository/workflow_node_execution_repository.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Optional, Protocol
|
||||||
|
|
||||||
|
from models.workflow import WorkflowNodeExecution
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrderConfig:
|
||||||
|
"""Configuration for ordering WorkflowNodeExecution instances."""
|
||||||
|
|
||||||
|
order_by: list[str]
|
||||||
|
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeExecutionRepository(Protocol):
|
||||||
|
"""
|
||||||
|
Repository interface for WorkflowNodeExecution.
|
||||||
|
|
||||||
|
This interface defines the contract for accessing and manipulating
|
||||||
|
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
||||||
|
|
||||||
|
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||||
|
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||||
|
the core interface. This keeps the core domain model clean and independent of specific
|
||||||
|
application domains or deployment scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Save a WorkflowNodeExecution instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to save
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_execution_id: The node execution ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The WorkflowNodeExecution instance if found, None otherwise
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_by_workflow_run(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
order_config: Optional[OrderConfig] = None,
|
||||||
|
) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
order_config: Optional configuration for ordering results
|
||||||
|
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||||
|
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of running WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Update an existing WorkflowNodeExecution instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to update
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""
|
||||||
|
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
|
||||||
|
|
||||||
|
This method is intended to be used for bulk deletion operations, such as removing
|
||||||
|
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
|
||||||
|
"""
|
||||||
|
...
|
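Because WorkflowNodeExecutionRepository is a Protocol, any class with matching method signatures satisfies it structurally, with no inheritance required. A minimal in-memory sketch (illustrative only, not part of the diff; the plain "running" status string is an assumption):

from collections.abc import Sequence
from typing import Optional

from core.repository.workflow_node_execution_repository import OrderConfig
from models.workflow import WorkflowNodeExecution


class InMemoryWorkflowNodeExecutionRepository:
    """Illustrative in-memory stand-in; ordering via OrderConfig is omitted for brevity."""

    def __init__(self) -> None:
        self._items: list[WorkflowNodeExecution] = []

    def save(self, execution: WorkflowNodeExecution) -> None:
        self._items.append(execution)

    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
        return next((e for e in self._items if e.node_execution_id == node_execution_id), None)

    def get_by_workflow_run(
        self, workflow_run_id: str, order_config: Optional[OrderConfig] = None
    ) -> Sequence[WorkflowNodeExecution]:
        return [e for e in self._items if e.workflow_run_id == workflow_run_id]

    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
        # status value assumed to be stored as the plain string "running"
        return [e for e in self._items if e.workflow_run_id == workflow_run_id and e.status == "running"]

    def update(self, execution: WorkflowNodeExecution) -> None:
        # instances are mutated in place, so nothing to do in this sketch
        ...

    def clear(self) -> None:
        self._items.clear()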
@ -16,7 +16,7 @@ from core.variables.segments import StringSegment
|
|||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
|
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||||
@ -251,7 +251,12 @@ class AgentNode(ToolNode):
|
|||||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||||
]
|
]
|
||||||
value["history_prompt_messages"] = history_prompt_messages
|
value["history_prompt_messages"] = history_prompt_messages
|
||||||
value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
|
if model_schema:
|
||||||
|
# remove structured output feature to support old version agent plugin
|
||||||
|
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||||
|
value["entity"] = model_schema.model_dump(mode="json")
|
||||||
|
else:
|
||||||
|
value["entity"] = None
|
||||||
result[parameter_name] = value
|
result[parameter_name] = value
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@ -348,3 +353,10 @@ class AgentNode(ToolNode):
|
|||||||
)
|
)
|
||||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||||
return model_instance, model_schema
|
return model_instance, model_schema
|
||||||
|
|
||||||
|
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||||
|
if model_schema.features:
|
||||||
|
for feature in model_schema.features:
|
||||||
|
if feature.value not in AgentOldVersionModelFeatures:
|
||||||
|
model_schema.features.remove(feature)
|
||||||
|
return model_schema
|
||||||
|
@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData):
|
|||||||
class ParamsAutoGenerated(Enum):
|
class ParamsAutoGenerated(Enum):
|
||||||
CLOSE = 0
|
CLOSE = 0
|
||||||
OPEN = 1
|
OPEN = 1
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOldVersionModelFeatures(Enum):
|
||||||
|
"""
|
||||||
|
Enum class for old SDK version llm feature.
|
||||||
|
"""
|
||||||
|
|
||||||
|
TOOL_CALL = "tool-call"
|
||||||
|
MULTI_TOOL_CALL = "multi-tool-call"
|
||||||
|
AGENT_THOUGHT = "agent-thought"
|
||||||
|
VISION = "vision"
|
||||||
|
STREAM_TOOL_CALL = "stream-tool-call"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
VIDEO = "video"
|
||||||
|
AUDIO = "audio"
|
||||||
|
@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||||||
for answer_node_id, route_position in self.route_position.items():
|
for answer_node_id, route_position in self.route_position.items():
|
||||||
if answer_node_id not in self.rest_node_ids:
|
if answer_node_id not in self.rest_node_ids:
|
||||||
continue
|
continue
|
||||||
# exclude current node id
|
# Remove current node id from answer dependencies to support stream output if it is a success branch
|
||||||
answer_dependencies = self.generate_routes.answer_dependencies
|
answer_dependencies = self.generate_routes.answer_dependencies
|
||||||
if event.node_id in answer_dependencies[answer_node_id]:
|
edge_mapping = self.graph.edge_mapping.get(event.node_id)
|
||||||
|
success_edge = (
|
||||||
|
next(
|
||||||
|
(
|
||||||
|
edge
|
||||||
|
for edge in edge_mapping
|
||||||
|
if edge.run_condition
|
||||||
|
and edge.run_condition.type == "branch_identify"
|
||||||
|
and edge.run_condition.branch_identify == "success-branch"
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if edge_mapping
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
event.node_id in answer_dependencies[answer_node_id]
|
||||||
|
and success_edge
|
||||||
|
and success_edge.target_node_id == answer_node_id
|
||||||
|
):
|
||||||
answer_dependencies[answer_node_id].remove(event.node_id)
|
answer_dependencies[answer_node_id].remove(event.node_id)
|
||||||
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
|
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
|
||||||
# all depends on answer node id not in rest node ids
|
# all depends on answer node id not in rest node ids
|
||||||
|
@ -90,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
|
|||||||
params: str
|
params: str
|
||||||
body: Optional[HttpRequestNodeBody] = None
|
body: Optional[HttpRequestNodeBody] = None
|
||||||
timeout: Optional[HttpRequestNodeTimeout] = None
|
timeout: Optional[HttpRequestNodeTimeout] = None
|
||||||
|
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||||
|
|
||||||
|
|
||||||
class Response:
|
class Response:
|
||||||
|
@ -88,6 +88,7 @@ class Executor:
|
|||||||
self.method = node_data.method
|
self.method = node_data.method
|
||||||
self.auth = node_data.authorization
|
self.auth = node_data.authorization
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
self.ssl_verify = node_data.ssl_verify
|
||||||
self.params = []
|
self.params = []
|
||||||
self.headers = {}
|
self.headers = {}
|
||||||
self.content = None
|
self.content = None
|
||||||
@ -316,6 +317,7 @@ class Executor:
|
|||||||
"headers": headers,
|
"headers": headers,
|
||||||
"params": self.params,
|
"params": self.params,
|
||||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||||
|
"ssl_verify": self.ssl_verify,
|
||||||
"follow_redirects": True,
|
"follow_redirects": True,
|
||||||
"max_retries": self.max_retries,
|
"max_retries": self.max_retries,
|
||||||
}
|
}
|
||||||
|
@ -51,6 +51,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||||||
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||||
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||||
},
|
},
|
||||||
|
"ssl_verify": dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||||
},
|
},
|
||||||
"retry_config": {
|
"retry_config": {
|
||||||
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||||
|
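For reference, the new ssl_verify switch surfaces in the HTTP-request node's data roughly as in the hedged fragment below; only the ssl_verify key comes from this diff, the other keys are assumed parts of the node schema.

node_data_fragment = {
    "method": "get",                                 # assumed existing field
    "url": "https://self-signed.internal.example",   # assumed existing field
    "headers": "",
    "params": "",
    "ssl_verify": False,  # new in this PR; defaults to dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
}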
@ -149,7 +149,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||||||
def _extract_slice(
|
def _extract_slice(
|
||||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - 1
|
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
|
||||||
|
if value < 1:
|
||||||
|
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||||
|
value -= 1
|
||||||
if len(variable.value) > int(value):
|
if len(variable.value) > int(value):
|
||||||
result = variable.value[value]
|
result = variable.value[value]
|
||||||
else:
|
else:
|
||||||
|
@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
|
|||||||
memory: Optional[MemoryConfig] = None
|
memory: Optional[MemoryConfig] = None
|
||||||
context: ContextConfig
|
context: ContextConfig
|
||||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||||
|
structured_output: dict | None = None
|
||||||
|
structured_output_enabled: bool = False
|
||||||
|
|
||||||
@field_validator("prompt_config", mode="before")
|
@field_validator("prompt_config", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||||
|
|
||||||
|
import json_repair
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
|
from core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
ModelFeature,
|
||||||
|
ModelPropertyKey,
|
||||||
|
ModelType,
|
||||||
|
ParameterRule,
|
||||||
|
)
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.entities.plugin import ModelProviderID
|
from core.plugin.entities.plugin import ModelProviderID
|
||||||
@ -57,6 +65,12 @@ from core.workflow.nodes.event import (
|
|||||||
RunRetrieverResourceEvent,
|
RunRetrieverResourceEvent,
|
||||||
RunStreamChunkEvent,
|
RunStreamChunkEvent,
|
||||||
)
|
)
|
||||||
|
from core.workflow.utils.structured_output.entities import (
|
||||||
|
ResponseFormat,
|
||||||
|
SpecialModelType,
|
||||||
|
SupportStructuredOutputStatus,
|
||||||
|
)
|
||||||
|
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
|
||||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import Conversation
|
from models.model import Conversation
|
||||||
@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
_node_type = NodeType.LLM
|
_node_type = NodeType.LLM
|
||||||
|
|
||||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||||
|
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
|
||||||
|
"""Process structured output if enabled"""
|
||||||
|
if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
|
||||||
|
return None
|
||||||
|
return self._parse_structured_output(text)
|
||||||
|
|
||||||
node_inputs: Optional[dict[str, Any]] = None
|
node_inputs: Optional[dict[str, Any]] = None
|
||||||
process_data = None
|
process_data = None
|
||||||
result_text = ""
|
result_text = ""
|
||||||
@ -130,7 +150,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
if isinstance(event, RunRetrieverResourceEvent):
|
if isinstance(event, RunRetrieverResourceEvent):
|
||||||
context = event.context
|
context = event.context
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
node_inputs["#context#"] = context
|
node_inputs["#context#"] = context
|
||||||
|
|
||||||
@ -192,7 +211,9 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||||
break
|
break
|
||||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||||
|
structured_output = process_structured_output(result_text)
|
||||||
|
if structured_output:
|
||||||
|
outputs["structured_output"] = structured_output
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
run_result=NodeRunResult(
|
run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
@ -513,7 +534,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
|
|
||||||
if not model_schema:
|
if not model_schema:
|
||||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||||
|
support_structured_output = self._check_model_structured_output_support()
|
||||||
|
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
|
||||||
|
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
|
||||||
|
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
|
||||||
|
# Set appropriate response format based on model capabilities
|
||||||
|
self._set_response_format(completion_params, model_schema.parameter_rules)
|
||||||
return model_instance, ModelConfigWithCredentialsEntity(
|
return model_instance, ModelConfigWithCredentialsEntity(
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
@ -724,10 +750,29 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
"No prompt found in the LLM configuration. "
|
"No prompt found in the LLM configuration. "
|
||||||
"Please ensure a prompt is properly configured before proceeding."
|
"Please ensure a prompt is properly configured before proceeding."
|
||||||
)
|
)
|
||||||
|
support_structured_output = self._check_model_structured_output_support()
|
||||||
|
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
|
||||||
|
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||||
|
prompt_messages=filtered_prompt_messages,
|
||||||
|
)
|
||||||
stop = model_config.stop
|
stop = model_config.stop
|
||||||
return filtered_prompt_messages, stop
|
return filtered_prompt_messages, stop
|
||||||
|
|
||||||
|
def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
|
||||||
|
structured_output: dict[str, Any] | list[Any] = {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(result_text)
|
||||||
|
if not isinstance(parsed, (dict | list)):
|
||||||
|
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||||
|
structured_output = parsed
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
# if the result_text is not a valid json, try to repair it
|
||||||
|
parsed = json_repair.loads(result_text)
|
||||||
|
if not isinstance(parsed, (dict | list)):
|
||||||
|
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
|
||||||
|
structured_output = parsed
|
||||||
|
return structured_output
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||||
provider_model_bundle = model_instance.provider_model_bundle
|
provider_model_bundle = model_instance.provider_model_bundle
|
||||||
@ -926,6 +971,166 @@ class LLMNode(BaseNode[LLMNodeData]):
|
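The _parse_structured_output helper above falls back to the newly added json-repair dependency when strict json.loads fails. A tiny illustration (the repaired output shown in the comment is approximate):

import json_repair

json_repair.loads('{"name": "John Doe", "age": 30,}')
# -> {'name': 'John Doe', 'age': 30}  (trailing comma repaired instead of raising)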
|||||||
|
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
|
|
||||||
|
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
|
||||||
|
"""
|
||||||
|
Handle structured output for models with native JSON schema support.
|
||||||
|
|
||||||
|
:param model_parameters: Model parameters to update
|
||||||
|
:param rules: Model parameter rules
|
||||||
|
:return: Updated model parameters with JSON schema configuration
|
||||||
|
"""
|
||||||
|
# Process schema according to model requirements
|
||||||
|
schema = self._fetch_structured_output_schema()
|
||||||
|
schema_json = self._prepare_schema_for_model(schema)
|
||||||
|
|
||||||
|
# Set JSON schema in parameters
|
||||||
|
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Set appropriate response format if required by the model
|
||||||
|
for rule in rules:
|
||||||
|
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||||
|
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||||
|
|
||||||
|
return model_parameters
|
||||||
|
|
||||||
|
def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
|
||||||
|
"""
|
||||||
|
Handle structured output for models without native JSON schema support.
|
||||||
|
This function modifies the prompt messages to include schema-based output requirements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_messages: Original sequence of prompt messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[PromptMessage]: Updated prompt messages with structured output requirements
|
||||||
|
"""
|
||||||
|
# Convert schema to string format
|
||||||
|
schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
|
||||||
|
|
||||||
|
# Find existing system prompt with schema placeholder
|
||||||
|
system_prompt = next(
|
||||||
|
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
|
||||||
|
# Prepare system prompt content
|
||||||
|
system_prompt_content = (
|
||||||
|
structured_output_prompt + "\n\n" + system_prompt.content
|
||||||
|
if system_prompt and isinstance(system_prompt.content, str)
|
||||||
|
else structured_output_prompt
|
||||||
|
)
|
||||||
|
system_prompt = SystemPromptMessage(content=system_prompt_content)
|
||||||
|
|
||||||
|
# Extract content from the last user message
|
||||||
|
|
||||||
|
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
|
||||||
|
updated_prompt = [system_prompt] + filtered_prompts
|
||||||
|
|
||||||
|
return updated_prompt
|
||||||
|
|
||||||
|
def _set_response_format(self, model_parameters: dict, rules: list) -> None:
|
||||||
|
"""
|
||||||
|
Set the appropriate response format parameter based on model rules.
|
||||||
|
|
||||||
|
:param model_parameters: Model parameters to update
|
||||||
|
:param rules: Model parameter rules
|
||||||
|
"""
|
||||||
|
for rule in rules:
|
||||||
|
if rule.name == "response_format":
|
||||||
|
if ResponseFormat.JSON.value in rule.options:
|
||||||
|
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||||
|
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||||
|
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||||
|
|
||||||
|
def _prepare_schema_for_model(self, schema: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Prepare JSON schema based on model requirements.
|
||||||
|
|
||||||
|
Different models have different requirements for JSON schema formatting.
|
||||||
|
This function handles these differences.
|
||||||
|
|
||||||
|
:param schema: The original JSON schema
|
||||||
|
:return: Processed schema compatible with the current model
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Deep copy to avoid modifying the original schema
|
||||||
|
processed_schema = schema.copy()
|
||||||
|
|
||||||
|
# Convert boolean types to string types (common requirement)
|
||||||
|
convert_boolean_to_string(processed_schema)
|
||||||
|
|
||||||
|
# Apply model-specific transformations
|
||||||
|
if SpecialModelType.GEMINI in self.node_data.model.name:
|
||||||
|
remove_additional_properties(processed_schema)
|
||||||
|
return processed_schema
|
||||||
|
elif SpecialModelType.OLLAMA in self.node_data.model.provider:
|
||||||
|
return processed_schema
|
||||||
|
else:
|
||||||
|
# Default format with name field
|
||||||
|
return {"schema": processed_schema, "name": "llm_response"}
|
||||||
|
|
||||||
|
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
|
||||||
|
"""
|
||||||
|
Fetch model schema
|
||||||
|
"""
|
||||||
|
model_name = self.node_data.model.name
|
||||||
|
model_manager = ModelManager()
|
||||||
|
model_instance = model_manager.get_model_instance(
|
||||||
|
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
|
||||||
|
)
|
||||||
|
model_type_instance = model_instance.model_type_instance
|
||||||
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||||
|
model_credentials = model_instance.credentials
|
||||||
|
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||||
|
return model_schema
|
||||||
|
|
||||||
|
def _fetch_structured_output_schema(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Fetch the structured output schema from the node data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any]: The structured output schema
|
||||||
|
"""
|
||||||
|
if not self.node_data.structured_output:
|
||||||
|
raise LLMNodeError("Please provide a valid structured output schema")
|
||||||
|
structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
|
||||||
|
if not structured_output_schema:
|
||||||
|
raise LLMNodeError("Please provide a valid structured output schema")
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema = json.loads(structured_output_schema)
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
raise LLMNodeError("structured_output_schema must be a JSON object")
|
||||||
|
return schema
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
||||||
|
|
||||||
|
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
|
||||||
|
"""
|
||||||
|
Check if the current model supports structured output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SupportStructuredOutput: The support status of structured output
|
||||||
|
"""
|
||||||
|
# Early return if structured output is disabled
|
||||||
|
if (
|
||||||
|
not isinstance(self.node_data, LLMNodeData)
|
||||||
|
or not self.node_data.structured_output_enabled
|
||||||
|
or not self.node_data.structured_output
|
||||||
|
):
|
||||||
|
return SupportStructuredOutputStatus.DISABLED
|
||||||
|
# Get model schema and check if it exists
|
||||||
|
model_schema = self._fetch_model_schema(self.node_data.model.provider)
|
||||||
|
if not model_schema:
|
||||||
|
return SupportStructuredOutputStatus.DISABLED
|
||||||
|
|
||||||
|
# Check if model supports structured output feature
|
||||||
|
return (
|
||||||
|
SupportStructuredOutputStatus.SUPPORTED
|
||||||
|
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
|
||||||
|
else SupportStructuredOutputStatus.UNSUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
|
||||||
match role:
|
match role:
|
||||||
@ -1064,3 +1269,49 @@ def _handle_completion_template(
|
|||||||
)
|
)
|
||||||
prompt_messages.append(prompt_message)
|
prompt_messages.append(prompt_message)
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
|
|
||||||
|
|
||||||
|
def remove_additional_properties(schema: dict) -> None:
|
||||||
|
"""
|
||||||
|
Remove additionalProperties fields from JSON schema.
|
||||||
|
Used for models like Gemini that don't support this property.
|
||||||
|
|
||||||
|
:param schema: JSON schema to modify in-place
|
||||||
|
"""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove additionalProperties at current level
|
||||||
|
schema.pop("additionalProperties", None)
|
||||||
|
|
||||||
|
# Process nested structures recursively
|
||||||
|
for value in schema.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
remove_additional_properties(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
remove_additional_properties(item)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_boolean_to_string(schema: dict) -> None:
|
||||||
|
"""
|
||||||
|
Convert boolean type specifications to string in JSON schema.
|
||||||
|
|
||||||
|
:param schema: JSON schema to modify in-place
|
||||||
|
"""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check for boolean type at current level
|
||||||
|
if schema.get("type") == "boolean":
|
||||||
|
schema["type"] = "string"
|
||||||
|
|
||||||
|
# Process nested dictionaries and lists recursively
|
||||||
|
for value in schema.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
convert_boolean_to_string(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
convert_boolean_to_string(item)
|
||||||
|
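A quick trace of the two schema helpers above on a toy schema (editor-added illustration, not part of the diff):

schema = {
    "type": "object",
    "additionalProperties": False,
    "properties": {"active": {"type": "boolean"}},
    "required": ["active"],
}
convert_boolean_to_string(schema)     # nested "boolean" type becomes "string"
remove_additional_properties(schema)  # "additionalProperties" stripped recursively
# schema == {"type": "object", "properties": {"active": {"type": "string"}}, "required": ["active"]}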
24
api/core/workflow/utils/structured_output/entities.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormat(StrEnum):
|
||||||
|
"""Constants for model response formats"""
|
||||||
|
|
||||||
|
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||||
|
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||||
|
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialModelType(StrEnum):
|
||||||
|
"""Constants for identifying model types"""
|
||||||
|
|
||||||
|
GEMINI = "gemini"
|
||||||
|
OLLAMA = "ollama"
|
||||||
|
|
||||||
|
|
||||||
|
class SupportStructuredOutputStatus(StrEnum):
|
||||||
|
"""Constants for structured output support status"""
|
||||||
|
|
||||||
|
SUPPORTED = "supported"
|
||||||
|
UNSUPPORTED = "unsupported"
|
||||||
|
DISABLED = "disabled"
|
17
api/core/workflow/utils/structured_output/prompt.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
|
||||||
|
constraints:
|
||||||
|
- You must output in JSON format.
|
||||||
|
- Do not output boolean value, use string type instead.
|
||||||
|
- Do not output integer or float value, use number type instead.
|
||||||
|
eg:
|
||||||
|
Here is the JSON schema:
|
||||||
|
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
|
||||||
|
|
||||||
|
Here is the user's question:
|
||||||
|
My name is John Doe and I am 30 years old.
|
||||||
|
|
||||||
|
output:
|
||||||
|
{"name": "John Doe", "age": 30}
|
||||||
|
Here is the JSON schema:
|
||||||
|
{{schema}}
|
||||||
|
""" # noqa: E501
|
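How the {{schema}} placeholder gets filled (this mirrors _handle_prompt_based_schema in the LLM node above; the example schema is made up):

import json

schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
system_text = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", json.dumps(schema, ensure_ascii=False))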
@ -26,9 +26,12 @@ def init_app(app: DifyApp):
|
|||||||
|
|
||||||
# Always add StreamHandler to log to console
|
# Always add StreamHandler to log to console
|
||||||
sh = logging.StreamHandler(sys.stdout)
|
sh = logging.StreamHandler(sys.stdout)
|
||||||
sh.addFilter(RequestIdFilter())
|
|
||||||
log_handlers.append(sh)
|
log_handlers.append(sh)
|
||||||
|
|
||||||
|
# Apply RequestIdFilter to all handlers
|
||||||
|
for handler in log_handlers:
|
||||||
|
handler.addFilter(RequestIdFilter())
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=dify_config.LOG_LEVEL,
|
level=dify_config.LOG_LEVEL,
|
||||||
format=dify_config.LOG_FORMAT,
|
format=dify_config.LOG_FORMAT,
|
||||||
|
18
api/extensions/ext_repositories.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
Extension for initializing repositories.
|
||||||
|
|
||||||
|
This extension registers repository implementations with the RepositoryFactory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dify_app import DifyApp
|
||||||
|
from repositories.repository_registry import register_repositories
|
||||||
|
|
||||||
|
|
||||||
|
def init_app(_app: DifyApp) -> None:
|
||||||
|
"""
|
||||||
|
Initialize repository implementations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_app: The Flask application instance (unused)
|
||||||
|
"""
|
||||||
|
register_repositories()
|
@ -73,11 +73,7 @@ class Storage:
|
|||||||
raise ValueError(f"unsupported storage type {storage_type}")
|
raise ValueError(f"unsupported storage type {storage_type}")
|
||||||
|
|
||||||
def save(self, filename, data):
|
def save(self, filename, data):
|
||||||
try:
|
|
||||||
self.storage_runner.save(filename, data)
|
self.storage_runner.save(filename, data)
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to save file {filename}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
|
def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
|
||||||
@ -86,49 +82,25 @@ class Storage:
|
|||||||
def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
|
def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
|
||||||
|
|
||||||
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
|
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
|
||||||
try:
|
|
||||||
if stream:
|
if stream:
|
||||||
return self.load_stream(filename)
|
return self.load_stream(filename)
|
||||||
else:
|
else:
|
||||||
return self.load_once(filename)
|
return self.load_once(filename)
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to load file {filename}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def load_once(self, filename: str) -> bytes:
|
def load_once(self, filename: str) -> bytes:
|
||||||
try:
|
|
||||||
return self.storage_runner.load_once(filename)
|
return self.storage_runner.load_once(filename)
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to load_once file {filename}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def load_stream(self, filename: str) -> Generator:
|
def load_stream(self, filename: str) -> Generator:
|
||||||
try:
|
|
||||||
return self.storage_runner.load_stream(filename)
|
return self.storage_runner.load_stream(filename)
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to load_stream file {filename}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def download(self, filename, target_filepath):
|
def download(self, filename, target_filepath):
|
||||||
try:
|
|
||||||
self.storage_runner.download(filename, target_filepath)
|
self.storage_runner.download(filename, target_filepath)
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to download file {filename}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def exists(self, filename):
|
def exists(self, filename):
|
||||||
try:
|
|
||||||
return self.storage_runner.exists(filename)
|
return self.storage_runner.exists(filename)
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to check file exists {filename}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def delete(self, filename):
|
def delete(self, filename):
|
||||||
try:
|
|
||||||
return self.storage_runner.delete(filename)
|
return self.storage_runner.delete(filename)
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to delete file {filename}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
storage = Storage()
|
storage = Storage()
|
||||||
|
@ -52,6 +52,7 @@ def build_from_mapping(
|
|||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
config: FileUploadConfig | None = None,
|
config: FileUploadConfig | None = None,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
||||||
|
|
||||||
@ -69,6 +70,7 @@ def build_from_mapping(
|
|||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config and not _is_file_valid_with_config(
|
if config and not _is_file_valid_with_config(
|
||||||
@ -87,12 +89,14 @@ def build_from_mappings(
|
|||||||
mappings: Sequence[Mapping[str, Any]],
|
mappings: Sequence[Mapping[str, Any]],
|
||||||
config: FileUploadConfig | None = None,
|
config: FileUploadConfig | None = None,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> Sequence[File]:
|
) -> Sequence[File]:
|
||||||
files = [
|
files = [
|
||||||
build_from_mapping(
|
build_from_mapping(
|
||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=config,
|
config=config,
|
||||||
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
for mapping in mappings
|
for mapping in mappings
|
||||||
]
|
]
|
||||||
@ -116,6 +120,7 @@ def _build_from_local_file(
|
|||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
upload_file_id = mapping.get("upload_file_id")
|
upload_file_id = mapping.get("upload_file_id")
|
||||||
if not upload_file_id:
|
if not upload_file_id:
|
||||||
@ -134,10 +139,16 @@ def _build_from_local_file(
|
|||||||
if row is None:
|
if row is None:
|
||||||
raise ValueError("Invalid upload file")
|
raise ValueError("Invalid upload file")
|
||||||
|
|
||||||
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||||
if file_type.value != mapping.get("type", "custom"):
|
specified_type = mapping.get("type", "custom")
|
||||||
|
|
||||||
|
if strict_type_validation and detected_file_type.value != specified_type:
|
||||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||||
|
|
||||||
|
file_type = (
|
||||||
|
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
|
||||||
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
filename=row.name,
|
filename=row.name,
|
||||||
@ -158,6 +169,7 @@ def _build_from_remote_url(
|
|||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
upload_file_id = mapping.get("upload_file_id")
|
upload_file_id = mapping.get("upload_file_id")
|
||||||
if upload_file_id:
|
if upload_file_id:
|
||||||
@ -174,10 +186,21 @@ def _build_from_remote_url(
|
|||||||
if upload_file is None:
|
if upload_file is None:
|
||||||
raise ValueError("Invalid upload file")
|
raise ValueError("Invalid upload file")
|
||||||
|
|
||||||
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type)
|
detected_file_type = _standardize_file_type(
|
||||||
if file_type.value != mapping.get("type", "custom"):
|
extension="." + upload_file.extension, mime_type=upload_file.mime_type
|
||||||
|
)
|
||||||
|
|
||||||
|
specified_type = mapping.get("type")
|
||||||
|
|
||||||
|
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||||
|
|
||||||
|
file_type = (
|
||||||
|
FileType(specified_type)
|
||||||
|
if specified_type and specified_type != FileType.CUSTOM.value
|
||||||
|
else detected_file_type
|
||||||
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
filename=upload_file.name,
|
filename=upload_file.name,
|
||||||
@ -237,6 +260,7 @@ def _build_from_tool_file(
|
|||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
|
strict_type_validation: bool = False,
|
||||||
) -> File:
|
) -> File:
|
||||||
tool_file = (
|
tool_file = (
|
||||||
db.session.query(ToolFile)
|
db.session.query(ToolFile)
|
||||||
@ -252,7 +276,16 @@ def _build_from_tool_file(
|
|||||||
|
|
||||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||||
|
|
||||||
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
|
||||||
|
|
||||||
|
specified_type = mapping.get("type")
|
||||||
|
|
||||||
|
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||||
|
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||||
|
|
||||||
|
file_type = (
|
||||||
|
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
|
||||||
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
|
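Hedged caller-side sketch of the new strict_type_validation flag (the mapping values and IDs are placeholders; only the keyword argument itself comes from this diff):

file = build_from_mapping(
    mapping={
        "transfer_method": "local_file",
        "upload_file_id": "<upload-file-uuid>",
        "type": "image",
    },
    tenant_id="<tenant-uuid>",
    strict_type_validation=True,  # raise if the detected type differs from the declared "type"
)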
@ -42,6 +42,7 @@ message_file_fields = {
|
|||||||
"size": fields.Integer,
|
"size": fields.Integer,
|
||||||
"transfer_method": fields.String,
|
"transfer_method": fields.String,
|
||||||
"belongs_to": fields.String(default="user"),
|
"belongs_to": fields.String(default="user"),
|
||||||
|
"upload_file_id": fields.String(default=None),
|
||||||
}
|
}
|
||||||
|
|
||||||
agent_thought_fields = {
|
agent_thought_fields = {
|
||||||
|
@ -1091,12 +1091,7 @@ class Message(db.Model): # type: ignore[name-defined]
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def retriever_resources(self):
|
def retriever_resources(self):
|
||||||
return (
|
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
|
||||||
db.session.query(DatasetRetrieverResource)
|
|
||||||
.filter(DatasetRetrieverResource.message_id == self.id)
|
|
||||||
.order_by(DatasetRetrieverResource.position.asc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message_files(self):
|
def message_files(self):
|
||||||
@ -1155,7 +1150,7 @@ class Message(db.Model): # type: ignore[name-defined]
|
|||||||
files.append(file)
|
files.append(file)
|
||||||
|
|
||||||
result = [
|
result = [
|
||||||
{"belongs_to": message_file.belongs_to, **file.to_dict()}
|
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||||
for (file, message_file) in zip(files, message_files)
|
for (file, message_file) in zip(files, message_files)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -510,7 +510,7 @@ class WorkflowRun(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowNodeExecutionTriggeredFrom(Enum):
|
class WorkflowNodeExecutionTriggeredFrom(StrEnum):
|
||||||
"""
|
"""
|
||||||
Workflow Node Execution Triggered From Enum
|
Workflow Node Execution Triggered From Enum
|
||||||
"""
|
"""
|
||||||
@ -518,21 +518,8 @@ class WorkflowNodeExecutionTriggeredFrom(Enum):
|
|||||||
SINGLE_STEP = "single-step"
|
SINGLE_STEP = "single-step"
|
||||||
WORKFLOW_RUN = "workflow-run"
|
WORKFLOW_RUN = "workflow-run"
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom":
|
|
||||||
"""
|
|
||||||
Get value of given mode.
|
|
||||||
|
|
||||||
:param value: mode value
|
class WorkflowNodeExecutionStatus(StrEnum):
|
||||||
:return: mode
|
|
||||||
"""
|
|
||||||
for mode in cls:
|
|
||||||
if mode.value == value:
|
|
||||||
return mode
|
|
||||||
raise ValueError(f"invalid workflow node execution triggered from value {value}")
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowNodeExecutionStatus(Enum):
|
|
||||||
"""
|
"""
|
||||||
Workflow Node Execution Status Enum
|
Workflow Node Execution Status Enum
|
||||||
"""
|
"""
|
||||||
@ -543,19 +530,6 @@ class WorkflowNodeExecutionStatus(Enum):
|
|||||||
EXCEPTION = "exception"
|
EXCEPTION = "exception"
|
||||||
RETRY = "retry"
|
RETRY = "retry"
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
|
|
||||||
"""
|
|
||||||
Get value of given mode.
|
|
||||||
|
|
||||||
:param value: mode value
|
|
||||||
:return: mode
|
|
||||||
"""
|
|
||||||
for mode in cls:
|
|
||||||
if mode.value == value:
|
|
||||||
return mode
|
|
||||||
raise ValueError(f"invalid workflow node execution status value {value}")
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowNodeExecution(Base):
|
class WorkflowNodeExecution(Base):
|
||||||
"""
|
"""
|
||||||
@ -656,6 +630,7 @@ class WorkflowNodeExecution(Base):
|
|||||||
@property
|
@property
|
||||||
def created_by_account(self):
|
def created_by_account(self):
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatedByRole(self.created_by_role)
|
||||||
|
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||||
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -663,6 +638,7 @@ class WorkflowNodeExecution(Base):
|
|||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
created_by_role = CreatedByRole(self.created_by_role)
|
created_by_role = CreatedByRole(self.created_by_role)
|
||||||
|
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
[virtualenvs]
|
|
||||||
in-project = true
|
|
||||||
create = true
|
|
||||||
prefer-active-python = true
|
|
@ -30,6 +30,7 @@ dependencies = [
|
|||||||
"gunicorn~=23.0.0",
|
"gunicorn~=23.0.0",
|
||||||
"httpx[socks]~=0.27.0",
|
"httpx[socks]~=0.27.0",
|
||||||
"jieba==0.42.1",
|
"jieba==0.42.1",
|
||||||
|
"json-repair>=0.41.1",
|
||||||
"langfuse~=2.51.3",
|
"langfuse~=2.51.3",
|
||||||
"langsmith~=0.1.77",
|
"langsmith~=0.1.77",
|
||||||
"mailchimp-transactional~=1.0.50",
|
"mailchimp-transactional~=1.0.50",
|
||||||
@ -104,6 +105,7 @@ dev = [
|
|||||||
"ruff~=0.11.5",
|
"ruff~=0.11.5",
|
||||||
"pytest~=8.3.2",
|
"pytest~=8.3.2",
|
||||||
"pytest-benchmark~=4.0.0",
|
"pytest-benchmark~=4.0.0",
|
||||||
|
"pytest-cov~=4.1.0",
|
||||||
"pytest-env~=1.1.3",
|
"pytest-env~=1.1.3",
|
||||||
"pytest-mock~=3.14.0",
|
"pytest-mock~=3.14.0",
|
||||||
"types-aiofiles~=24.1.0",
|
"types-aiofiles~=24.1.0",
|
||||||
@ -162,10 +164,7 @@ storage = [
|
|||||||
############################################################
|
############################################################
|
||||||
# [ Tools ] dependency group
|
# [ Tools ] dependency group
|
||||||
############################################################
|
############################################################
|
||||||
tools = [
|
tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
|
||||||
"cloudscraper~=1.2.71",
|
|
||||||
"nltk~=3.9.1",
|
|
||||||
]
|
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# [ VDB ] dependency group
|
# [ VDB ] dependency group
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
continue-on-collection-errors = true
|
continue-on-collection-errors = true
|
||||||
|
addopts = --cov=./api --cov-report=json --cov-report=xml
|
||||||
env =
|
env =
|
||||||
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||||
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
|
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
|
||||||
|
6
api/repositories/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
"""
|
||||||
|
Repository implementations for data access.
|
||||||
|
|
||||||
|
This package contains concrete implementations of the repository interfaces
|
||||||
|
defined in the core.repository package.
|
||||||
|
"""
|
87
api/repositories/repository_registry.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
Registry for repository implementations.
|
||||||
|
|
||||||
|
This module is responsible for registering factory functions with the repository factory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.repository.repository_factory import RepositoryFactory
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Storage type constants
|
||||||
|
STORAGE_TYPE_RDBMS = "rdbms"
|
||||||
|
STORAGE_TYPE_HYBRID = "hybrid"
|
||||||
|
|
||||||
|
|
||||||
|
def register_repositories() -> None:
|
||||||
|
"""
|
||||||
|
Register repository factory functions with the RepositoryFactory.
|
||||||
|
|
||||||
|
This function reads configuration settings to determine which repository
|
||||||
|
implementations to register.
|
||||||
|
"""
|
||||||
|
# Configure WorkflowNodeExecutionRepository factory based on configuration
|
||||||
|
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
|
||||||
|
|
||||||
|
# Check storage type and register appropriate implementation
|
||||||
|
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
|
||||||
|
# Register SQLAlchemy implementation for RDBMS storage
|
||||||
|
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
|
||||||
|
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
|
||||||
|
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
|
||||||
|
# Hybrid storage is not yet implemented
|
||||||
|
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
|
||||||
|
else:
|
||||||
|
# Unknown storage type
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
|
||||||
|
f"Supported types: {STORAGE_TYPE_RDBMS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||||
|
"""
|
||||||
|
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
|
||||||
|
|
||||||
|
This factory function creates a repository for the RDBMS storage type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Parameters for creating the repository, including:
|
||||||
|
- tenant_id: Required. The tenant ID for multi-tenancy.
|
||||||
|
- app_id: Optional. The application ID for filtering.
|
||||||
|
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
|
||||||
|
a new sessionmaker will be created using the global database engine.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A WorkflowNodeExecutionRepository instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required parameters are missing
|
||||||
|
"""
|
||||||
|
# Extract required parameters
|
||||||
|
tenant_id = params.get("tenant_id")
|
||||||
|
if tenant_id is None:
|
||||||
|
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
|
||||||
|
|
||||||
|
# Extract optional parameters
|
||||||
|
app_id = params.get("app_id")
|
||||||
|
|
||||||
|
# Use the session_factory from params if provided, otherwise create one using the global db engine
|
||||||
|
session_factory = params.get("session_factory")
|
||||||
|
if session_factory is None:
|
||||||
|
# Create a sessionmaker using the same engine as the global db session
|
||||||
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
|
|
||||||
|
# Create and return the repository
|
||||||
|
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||||
|
)
|
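Caller-side sketch of the registry's factory function above (editor-added; the DSN and IDs are placeholders, and session_factory is optional since it falls back to the global db engine):

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from repositories.repository_registry import create_workflow_node_execution_repository

engine = create_engine("postgresql+psycopg2://dify:dify@localhost:5432/dify")  # placeholder DSN
repo = create_workflow_node_execution_repository(
    {
        "tenant_id": "<tenant-uuid>",                   # required
        "app_id": "<app-uuid>",                         # optional filter
        "session_factory": sessionmaker(bind=engine),   # optional
    }
)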
9
api/repositories/workflow_node_execution/__init__.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
"""
|
||||||
|
WorkflowNodeExecution repository implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||||
|
]
|
192
api/repositories/workflow_node_execution/sqlalchemy_repository.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||||
|
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLAlchemyWorkflowNodeExecutionRepository:
|
||||||
|
"""
|
||||||
|
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
|
||||||
|
|
||||||
|
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
||||||
|
Each method creates its own session, handles the transaction, and commits changes
|
||||||
|
to the database. This prevents long-running connections in the workflow core.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
||||||
|
tenant_id: Tenant ID for multi-tenancy
|
||||||
|
app_id: Optional app ID for filtering by application
|
||||||
|
"""
|
||||||
|
# If an engine is provided, create a sessionmaker from it
|
||||||
|
if isinstance(session_factory, Engine):
|
||||||
|
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||||
|
else:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
|
||||||
|
self._tenant_id = tenant_id
|
||||||
|
self._app_id = app_id
|
||||||
|
|
||||||
|
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Save a WorkflowNodeExecution instance and commit changes to the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to save
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
# Ensure tenant_id is set
|
||||||
|
if not execution.tenant_id:
|
||||||
|
execution.tenant_id = self._tenant_id
|
||||||
|
|
||||||
|
# Set app_id if provided and not already set
|
||||||
|
if self._app_id and not execution.app_id:
|
||||||
|
execution.app_id = self._app_id
|
||||||
|
|
||||||
|
session.add(execution)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_execution_id: The node execution ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The WorkflowNodeExecution instance if found, None otherwise
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
|
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||||
|
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._app_id:
|
||||||
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
|
return session.scalar(stmt)
|
||||||
|
|
||||||
|
def get_by_workflow_run(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
order_config: Optional[OrderConfig] = None,
|
||||||
|
) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
order_config: Optional configuration for ordering results
|
||||||
|
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||||
|
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
|
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||||
|
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||||
|
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._app_id:
|
||||||
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
|
# Apply ordering if provided
|
||||||
|
if order_config and order_config.order_by:
|
||||||
|
order_columns: list[UnaryExpression] = []
|
||||||
|
for field in order_config.order_by:
|
||||||
|
column = getattr(WorkflowNodeExecution, field, None)
|
||||||
|
if not column:
|
||||||
|
continue
|
||||||
|
if order_config.order_direction == "desc":
|
||||||
|
order_columns.append(desc(column))
|
||||||
|
else:
|
||||||
|
order_columns.append(asc(column))
|
||||||
|
|
||||||
|
if order_columns:
|
||||||
|
stmt = stmt.order_by(*order_columns)
|
||||||
|
|
||||||
|
return session.scalars(stmt).all()
|
||||||
|
|
||||||
|
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of running WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
|
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||||
|
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||||
|
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||||
|
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._app_id:
|
||||||
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
|
return session.scalars(stmt).all()
|
||||||
|
|
||||||
|
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Update an existing WorkflowNodeExecution instance and commit changes to the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to update
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
# Ensure tenant_id is set
|
||||||
|
if not execution.tenant_id:
|
||||||
|
execution.tenant_id = self._tenant_id
|
||||||
|
|
||||||
|
# Set app_id if provided and not already set
|
||||||
|
if self._app_id and not execution.app_id:
|
||||||
|
execution.app_id = self._app_id
|
||||||
|
|
||||||
|
session.merge(execution)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""
|
||||||
|
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
|
||||||
|
|
||||||
|
This method deletes all WorkflowNodeExecution records that match the tenant_id
|
||||||
|
and app_id (if provided) associated with this repository instance.
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||||
|
|
||||||
|
if self._app_id:
|
||||||
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
|
result = session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
logger.info(
|
||||||
|
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
|
||||||
|
+ (f" and app {self._app_id}" if self._app_id else "")
|
||||||
|
)
|
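Note (illustrative, not part of the diff): a short sketch of using the repository above directly, assuming `engine` is an already-configured SQLAlchemy Engine and `execution` is an existing WorkflowNodeExecution instance.

# Sketch only: `engine` and `execution` are assumed to exist.
from core.repository.workflow_node_execution_repository import OrderConfig
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository

repo = SQLAlchemyWorkflowNodeExecutionRepository(session_factory=engine, tenant_id="tenant-id", app_id="app-id")
repo.save(execution)  # fills in tenant_id/app_id if unset, then commits
latest_first = repo.get_by_workflow_run(
    workflow_run_id="run-id",
    order_config=OrderConfig(order_by=["index"], order_direction="desc"),
)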
@ -407,10 +407,8 @@ class AccountService:
            raise PasswordResetRateLimitExceededError()
        code = "".join([str(random.randint(0, 9)) for _ in range(6)])
        code, token = cls.generate_reset_password_token(account_email, account)
        token = TokenManager.generate_token(
            account=account, email=email, token_type="reset_password", additional_data={"code": code}
        )
        send_reset_password_mail_task.delay(
            language=language,
            to=account_email,
@ -419,6 +417,22 @@ class AccountService:
        cls.reset_password_rate_limiter.increment_rate_limit(account_email)
        return token

    @classmethod
    def generate_reset_password_token(
        cls,
        email: str,
        account: Optional[Account] = None,
        code: Optional[str] = None,
        additional_data: dict[str, Any] = {},
    ):
        if not code:
            code = "".join([str(random.randint(0, 9)) for _ in range(6)])
        additional_data["code"] = code
        token = TokenManager.generate_token(
            account=account, email=email, token_type="reset_password", additional_data=additional_data
        )
        return code, token

    @classmethod
    def revoke_reset_password_token(cls, token: str):
        TokenManager.revoke_token(token, "reset_password")
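Note (illustrative, not part of the diff): the extracted helper returns both the 6-digit code and the token, so a caller can reuse a known code when reissuing a token. A sketch under that assumption:

# Sketch only: `account_email` and `account` are assumed to exist.
code, token = AccountService.generate_reset_password_token(account_email, account)
# Reissue a token that carries the same code (e.g. when resending the mail).
_, new_token = AccountService.generate_reset_password_token(account_email, account, code=code)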
@ -553,7 +553,7 @@ class DocumentService:
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
        "segmentation": {"delimiter": "\n", "max_tokens": 1024, "chunk_overlap": 50},
    },
    "limits": {
        "indexing_max_segmentation_tokens_length": dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH,
@ -2025,7 +2025,7 @@ class SegmentService:
            dataset_id=dataset.id,
            document_id=document.id,
            segment_id=segment.id,
            position=max_position + 1,
            position=max_position + 1 if max_position else 1,
            index_node_id=index_node_id,
            index_node_hash=index_node_hash,
            content=content,
@ -2175,7 +2175,13 @@ class SegmentService:

    @classmethod
    def get_segments(
        cls, document_id: str, tenant_id: str, status_list: list[str] | None = None, keyword: str | None = None
        cls,
        document_id: str,
        tenant_id: str,
        status_list: list[str] | None = None,
        keyword: str | None = None,
        page: int = 1,
        limit: int = 20,
    ):
        """Get segments for a document with optional filtering."""
        query = DocumentSegment.query.filter(
@ -2188,10 +2194,11 @@ class SegmentService:
        if keyword:
            query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))

        segments = query.order_by(DocumentSegment.position.asc()).all()
        paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
        total = len(segments)
            page=page, per_page=limit, max_per_page=100, error_out=False
        )

        return segments, total
        return paginated_segments.items, paginated_segments.total

    @classmethod
    def update_segment_by_id(
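Note (illustrative, not part of the diff): with the new signature, callers page through segments instead of loading them all, and Flask-SQLAlchemy's paginate() caps per_page at 100. A sketch with placeholder values:

# Sketch only: the document and tenant IDs are placeholders.
items, total = SegmentService.get_segments(
    document_id="document-id",
    tenant_id="tenant-id",
    status_list=["completed"],
    keyword="invoice",
    page=2,
    limit=50,
)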
@ -1,3 +1,4 @@
from configs import dify_config
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
from core.plugin.manager.plugin import PluginInstallationManager
@ -111,6 +112,8 @@ class DependenciesAnalysisService:
        Generate the latest version of dependencies
        """
        dependencies = list(set(dependencies))
        if not dify_config.MARKETPLACE_ENABLED:
            return []
        deps = marketplace.batch_fetch_plugin_manifests(dependencies)
        return [
            PluginDependency(
@ -2,13 +2,14 @@ import threading
from typing import Optional

import contexts
from core.repository import RepositoryFactory
from core.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
)

@ -127,17 +128,17 @@ class WorkflowRunService:
        if not workflow_run:
            return []

        node_executions = (
        # Use the repository to get the node executions
            db.session.query(WorkflowNodeExecution)
        repository = RepositoryFactory.create_workflow_node_execution_repository(
            .filter(
            params={
                WorkflowNodeExecution.tenant_id == app_model.tenant_id,
                "tenant_id": app_model.tenant_id,
                WorkflowNodeExecution.app_id == app_model.id,
                "app_id": app_model.id,
                WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
                "session_factory": db.session.get_bind,
                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
            }
                WorkflowNodeExecution.workflow_run_id == run_id,
            )
            .order_by(WorkflowNodeExecution.index.desc())
            .all()
        )

        return node_executions
        # Use the repository to get the node executions with ordering
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)

        return list(node_executions)
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repository import RepositoryFactory
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
@ -282,8 +283,15 @@ class WorkflowService:
        workflow_node_execution.created_by = account.id
        workflow_node_execution.workflow_id = draft_workflow.id

        db.session.add(workflow_node_execution)
        # Use the repository to save the workflow node execution
        db.session.commit()
        repository = RepositoryFactory.create_workflow_node_execution_repository(
            params={
                "tenant_id": app_model.tenant_id,
                "app_id": app_model.id,
                "session_factory": db.session.get_bind,
            }
        )
        repository.save(workflow_node_execution)

        return workflow_node_execution
@ -7,6 +7,7 @@ from celery import shared_task  # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError

from core.repository import RepositoryFactory
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import (
@ -30,7 +31,7 @@ from models.model import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun


@shared_task(queue="app_deletion", bind=True, max_retries=3)
@ -187,17 +188,19 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):


def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
    def del_workflow_node_execution(workflow_node_execution_id: str):
    # Create a repository instance for WorkflowNodeExecution
        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
    repository = RepositoryFactory.create_workflow_node_execution_repository(
            synchronize_session=False
        params={
            "tenant_id": tenant_id,
            "app_id": app_id,
            "session_factory": db.session.get_bind,
        }
    )

    _delete_records(
    # Use the clear method to delete all records for this tenant_id and app_id
        """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
    repository.clear()
        {"tenant_id": tenant_id, "app_id": app_id},
        del_workflow_node_execution,
    logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green"))
        "workflow node execution",
    )


def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
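Note (illustrative, not part of the diff): in _delete_app_workflow_node_executions above, the per-row deletion loop is replaced by one tenant/app-scoped clear(). A sketch of that pattern on its own, with placeholder IDs:

# Sketch only: tenant_id and app_id values are placeholders.
repository = RepositoryFactory.create_workflow_node_execution_repository(
    params={"tenant_id": "tenant-id", "app_id": "app-id", "session_factory": db.session.get_bind}
)
repository.clear()  # issues a single DELETE filtered by tenant_id/app_id, commits, and logs the count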
@ -0,0 +1,99 @@
from unittest.mock import MagicMock, patch

from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call

ToolCall = AssistantPromptMessage.ToolCall

# CASE 1: Single tool call
INPUTS_CASE_1 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_1 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
]

# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_2 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
    ToolCall(
        id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
    ),
]

# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_3 = [
    ToolCall(
        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
    ),
    ToolCall(
        id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
    ),
]

# CASE 4: Tool call sequences with no IDs
INPUTS_CASE_4 = [
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_4 = [
    ToolCall(
        id="RANDOM_ID_1",
        type="function",
        function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
    ),
    ToolCall(
        id="RANDOM_ID_2",
        type="function",
        function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
    ),
]


def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
    actual = []
    _increase_tool_call(inputs, actual)
    assert actual == expected


def test__increase_tool_call():
    # case 1:
    _run_case(INPUTS_CASE_1, EXPECTED_CASE_1)

    # case 2:
    _run_case(INPUTS_CASE_2, EXPECTED_CASE_2)

    # case 3:
    _run_case(INPUTS_CASE_3, EXPECTED_CASE_3)

    # case 4:
    mock_id_generator = MagicMock()
    mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
    with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
        _run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
@ -1,14 +1,20 @@
from unittest.mock import patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
    GraphRunPartialSucceededEvent,
    NodeRunExceptionEvent,
    NodeRunFailedEvent,
    NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.workflow import WorkflowType
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


class ContinueOnErrorTestHelper:
@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error():
        "edges": FAIL_BRANCH_EDGES[:-1],
        "nodes": [
            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
            {
            {"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
                "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
                "id": "success",
            },
            ContinueOnErrorTestHelper.get_http_node(),
        ],
    }
@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error():
    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
    assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0


def test_stream_output_with_fail_branch_continue_on_error():
    """Test stream output with fail-branch error strategy"""
    graph_config = {
        "edges": FAIL_BRANCH_EDGES,
        "nodes": [
            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
            {
                "data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
                "id": "success",
            },
            {
                "data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
                "id": "error",
            },
            ContinueOnErrorTestHelper.get_llm_node(),
        ],
    }
    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)

    def llm_generator(self):
        contents = ["hi", "bye", "good morning"]

        yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])

        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs={},
                process_data={},
                outputs={},
                metadata={
                    NodeRunMetadataKey.TOTAL_TOKENS: 1,
                    NodeRunMetadataKey.TOTAL_PRICE: 1,
                    NodeRunMetadataKey.CURRENCY: "USD",
                },
            )
        )

    with patch.object(LLMNode, "_run", new=llm_generator):
        events = list(graph_engine.run())
        assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
        assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)
198
api/tests/unit_tests/factories/test_build_from_mapping.py
Normal file
@ -0,0 +1,198 @@
import uuid
from unittest.mock import MagicMock, patch

import pytest
from httpx import Response

from factories.file_factory import (
    File,
    FileTransferMethod,
    FileType,
    FileUploadConfig,
    build_from_mapping,
)
from models import ToolFile, UploadFile

# Test Data
TEST_TENANT_ID = "test_tenant_id"
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
TEST_TOOL_FILE_ID = str(uuid.uuid4())
TEST_REMOTE_URL = "http://example.com/test.jpg"

# Test Config
TEST_CONFIG = FileUploadConfig(
    allowed_file_types=["image", "document"],
    allowed_file_extensions=[".jpg", ".pdf"],
    allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE],
    number_limits=10,
)


# Fixtures
@pytest.fixture
def mock_upload_file():
    mock = MagicMock(spec=UploadFile)
    mock.id = TEST_UPLOAD_FILE_ID
    mock.tenant_id = TEST_TENANT_ID
    mock.name = "test.jpg"
    mock.extension = "jpg"
    mock.mime_type = "image/jpeg"
    mock.source_url = TEST_REMOTE_URL
    mock.size = 1024
    mock.key = "test_key"
    with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
        yield m


@pytest.fixture
def mock_tool_file():
    mock = MagicMock(spec=ToolFile)
    mock.id = TEST_TOOL_FILE_ID
    mock.tenant_id = TEST_TENANT_ID
    mock.name = "tool_file.pdf"
    mock.file_key = "tool_file.pdf"
    mock.mimetype = "application/pdf"
    mock.original_url = "http://example.com/tool.pdf"
    mock.size = 2048
    with patch("factories.file_factory.db.session.query") as mock_query:
        mock_query.return_value.filter.return_value.first.return_value = mock
        yield mock


@pytest.fixture
def mock_http_head():
    def _mock_response(filename, size, content_type):
        return Response(
            status_code=200,
            headers={
                "Content-Disposition": f'attachment; filename="{filename}"',
                "Content-Length": str(size),
                "Content-Type": content_type,
            },
        )

    with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
        mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
        yield mock_head


# Helper functions
def local_file_mapping(file_type="image"):
    return {
        "transfer_method": "local_file",
        "upload_file_id": TEST_UPLOAD_FILE_ID,
        "type": file_type,
    }


def tool_file_mapping(file_type="document"):
    return {
        "transfer_method": "tool_file",
        "tool_file_id": TEST_TOOL_FILE_ID,
        "type": file_type,
    }


# Tests
def test_build_from_mapping_backward_compatibility(mock_upload_file):
    mapping = local_file_mapping(file_type="image")
    file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
    assert isinstance(file, File)
    assert file.transfer_method == FileTransferMethod.LOCAL_FILE
    assert file.type == FileType.IMAGE
    assert file.related_id == TEST_UPLOAD_FILE_ID


@pytest.mark.parametrize(
    ("file_type", "should_pass", "expected_error"),
    [
        ("image", True, None),
        ("document", False, "Detected file type does not match"),
    ],
)
def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error):
    mapping = local_file_mapping(file_type=file_type)
    if should_pass:
        file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
        assert file.type == FileType(file_type)
    else:
        with pytest.raises(ValueError, match=expected_error):
            build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)


@pytest.mark.parametrize(
    ("file_type", "should_pass", "expected_error"),
    [
        ("document", True, None),
        ("image", False, "Detected file type does not match"),
    ],
)
def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error):
    """Strict type validation for tool_file."""
    mapping = tool_file_mapping(file_type=file_type)
    if should_pass:
        file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
        assert file.type == FileType(file_type)
    else:
        with pytest.raises(ValueError, match=expected_error):
            build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)


def test_build_from_remote_url(mock_http_head):
    mapping = {
        "transfer_method": "remote_url",
        "url": TEST_REMOTE_URL,
        "type": "image",
    }
    file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
    assert file.transfer_method == FileTransferMethod.REMOTE_URL
    assert file.type == FileType.IMAGE
    assert file.filename == "remote_test.jpg"
    assert file.size == 2048


def test_tool_file_not_found():
    """Test ToolFile not found in database."""
    with patch("factories.file_factory.db.session.query") as mock_query:
        mock_query.return_value.filter.return_value.first.return_value = None
        mapping = tool_file_mapping()
        with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
            build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)


def test_local_file_not_found():
    """Test UploadFile not found in database."""
    with patch("factories.file_factory.db.session.scalar", return_value=None):
        mapping = local_file_mapping()
        with pytest.raises(ValueError, match="Invalid upload file"):
            build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)


def test_build_without_type_specification(mock_upload_file):
    """Test the situation where no file type is specified"""
    mapping = {
        "transfer_method": "local_file",
        "upload_file_id": TEST_UPLOAD_FILE_ID,
        # leave out the type
    }
    file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
    # It should automatically infer the type as "image" based on the file extension
    assert file.type == FileType.IMAGE


@pytest.mark.parametrize(
    ("file_type", "should_pass", "expected_error"),
    [
        ("image", True, None),
        ("video", False, "File validation failed"),
    ],
)
def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error):
    """Test the validation of files and configurations"""
    mapping = local_file_mapping(file_type=file_type)
    if should_pass:
        file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
        assert file is not None
    else:
        with pytest.raises(ValueError, match=expected_error):
            build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
3
api/tests/unit_tests/repositories/__init__.py
Normal file
@ -0,0 +1,3 @@
"""
Unit tests for repositories.
"""
@ -0,0 +1,3 @@
"""
Unit tests for workflow_node_execution repositories.
"""
@ -0,0 +1,178 @@
"""
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
"""

from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture
from sqlalchemy.orm import Session, sessionmaker

from core.repository.workflow_node_execution_repository import OrderConfig
from models.workflow import WorkflowNodeExecution
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository


@pytest.fixture
def session():
    """Create a mock SQLAlchemy session."""
    session = MagicMock(spec=Session)
    # Configure the session to be used as a context manager
    session.__enter__ = MagicMock(return_value=session)
    session.__exit__ = MagicMock(return_value=None)

    # Configure the session factory to return the session
    session_factory = MagicMock(spec=sessionmaker)
    session_factory.return_value = session
    return session, session_factory


@pytest.fixture
def repository(session):
    """Create a repository instance with test data."""
    _, session_factory = session
    tenant_id = "test-tenant"
    app_id = "test-app"
    return SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
    )


def test_save(repository, session):
    """Test save method."""
    session_obj, _ = session
    # Create a mock execution
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.tenant_id = None
    execution.app_id = None

    # Call save method
    repository.save(execution)

    # Assert tenant_id and app_id are set
    assert execution.tenant_id == repository._tenant_id
    assert execution.app_id == repository._app_id

    # Assert session.add was called
    session_obj.add.assert_called_once_with(execution)


def test_save_with_existing_tenant_id(repository, session):
    """Test save method with existing tenant_id."""
    session_obj, _ = session
    # Create a mock execution with existing tenant_id
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.tenant_id = "existing-tenant"
    execution.app_id = None

    # Call save method
    repository.save(execution)

    # Assert tenant_id is not changed and app_id is set
    assert execution.tenant_id == "existing-tenant"
    assert execution.app_id == repository._app_id

    # Assert session.add was called
    session_obj.add.assert_called_once_with(execution)


def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
    """Test get_by_node_execution_id method."""
    session_obj, _ = session
    # Set up mock
    mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
    mock_stmt = mocker.MagicMock()
    mock_select.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt
    session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)

    # Call method
    result = repository.get_by_node_execution_id("test-node-execution-id")

    # Assert select was called with correct parameters
    mock_select.assert_called_once()
    session_obj.scalar.assert_called_once_with(mock_stmt)
    assert result is not None


def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
    """Test get_by_workflow_run method."""
    session_obj, _ = session
    # Set up mock
    mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
    mock_stmt = mocker.MagicMock()
    mock_select.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt
    mock_stmt.order_by.return_value = mock_stmt
    session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]

    # Call method
    order_config = OrderConfig(order_by=["index"], order_direction="desc")
    result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)

    # Assert select was called with correct parameters
    mock_select.assert_called_once()
    session_obj.scalars.assert_called_once_with(mock_stmt)
    assert len(result) == 1


def test_get_running_executions(repository, session, mocker: MockerFixture):
    """Test get_running_executions method."""
    session_obj, _ = session
    # Set up mock
    mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
    mock_stmt = mocker.MagicMock()
    mock_select.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt
    session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]

    # Call method
    result = repository.get_running_executions("test-workflow-run-id")

    # Assert select was called with correct parameters
    mock_select.assert_called_once()
    session_obj.scalars.assert_called_once_with(mock_stmt)
    assert len(result) == 1


def test_update(repository, session):
    """Test update method."""
    session_obj, _ = session
    # Create a mock execution
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.tenant_id = None
    execution.app_id = None

    # Call update method
    repository.update(execution)

    # Assert tenant_id and app_id are set
    assert execution.tenant_id == repository._tenant_id
    assert execution.app_id == repository._app_id

    # Assert session.merge was called
    session_obj.merge.assert_called_once_with(execution)


def test_clear(repository, session, mocker: MockerFixture):
    """Test clear method."""
    session_obj, _ = session
    # Set up mock
    mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete")
    mock_stmt = mocker.MagicMock()
    mock_delete.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt

    # Mock the execute result with rowcount
    mock_result = mocker.MagicMock()
    mock_result.rowcount = 5  # Simulate 5 records deleted
    session_obj.execute.return_value = mock_result

    # Call method
    repository.clear()

    # Assert delete was called with correct parameters
    mock_delete.assert_called_once_with(WorkflowNodeExecution)
    mock_stmt.where.assert_called()
    session_obj.execute.assert_called_once_with(mock_stmt)
    session_obj.commit.assert_called_once()
63
api/uv.lock
generated
@ -1,5 +1,4 @@
version = 1
revision = 1
requires-python = ">=3.11, <3.13"
resolution-markers = [
    "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'",
@ -1012,6 +1011,11 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b7/00/14b00a0748e9eda26e97be07a63cc911108844004687321ddcc213be956c/coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3", size = 204347 },
]

[package.optional-dependencies]
toml = [
    { name = "tomli", marker = "python_full_version <= '3.11'" },
]

[[package]]
name = "crc32c"
version = "2.7.1"
@ -1173,6 +1177,7 @@ dependencies = [
    { name = "gunicorn" },
    { name = "httpx", extra = ["socks"] },
    { name = "jieba" },
    { name = "json-repair" },
    { name = "langfuse" },
    { name = "langsmith" },
    { name = "mailchimp-transactional" },
@ -1234,6 +1239,7 @@ dev = [
    { name = "mypy" },
    { name = "pytest" },
    { name = "pytest-benchmark" },
    { name = "pytest-cov" },
    { name = "pytest-env" },
    { name = "pytest-mock" },
    { name = "ruff" },
@ -1340,6 +1346,7 @@ requires-dist = [
    { name = "gunicorn", specifier = "~=23.0.0" },
    { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" },
    { name = "jieba", specifier = "==0.42.1" },
    { name = "json-repair", specifier = ">=0.41.1" },
    { name = "langfuse", specifier = "~=2.51.3" },
    { name = "langsmith", specifier = "~=0.1.77" },
    { name = "mailchimp-transactional", specifier = "~=1.0.50" },
@ -1401,6 +1408,7 @@ dev = [
    { name = "mypy", specifier = "~=1.15.0" },
    { name = "pytest", specifier = "~=8.3.2" },
    { name = "pytest-benchmark", specifier = "~=4.0.0" },
    { name = "pytest-cov", specifier = "~=4.1.0" },
    { name = "pytest-env", specifier = "~=1.1.3" },
    { name = "pytest-mock", specifier = "~=3.14.0" },
    { name = "ruff", specifier = "~=0.11.5" },
@ -2517,6 +2525,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 },
]

[[package]]
name = "json-repair"
version = "0.41.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 },
]

[[package]]
name = "jsonpath-python"
version = "1.0.6"
@ -4067,6 +4084,8 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 },
    { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 },
    { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 },
    { url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 },
    { url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 },
]

[[package]]
@ -4333,6 +4352,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/4d/a1/3b70862b5b3f830f0422844f25a823d0470739d994466be9dbbbb414d85a/pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6", size = 43951 },
]

[[package]]
name = "pytest-cov"
version = "4.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "coverage", extra = ["toml"] },
    { name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7a/15/da3df99fd551507694a9b01f512a2f6cf1254f33601605843c3775f39460/pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6", size = 63245 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a", size = 21949 },
]

[[package]]
name = "pytest-env"
version = "1.1.5"
@ -5235,6 +5267,35 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588 },
]

[[package]]
name = "tomli"
version = "2.2.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 },
    { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 },
    { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 },
    { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 },
    { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 },
    { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 },
    { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 },
    { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 },
    { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 },
    { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 },
    { url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762 },
    { url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453 },
    { url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486 },
    { url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349 },
    { url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159 },
    { url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243 },
    { url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645 },
    { url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584 },
    { url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875 },
    { url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418 },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tos"
|
name = "tos"
|
||||||
version = "2.7.2"
|
version = "2.7.2"
|
||||||
|
@ -174,6 +174,12 @@ CELERY_MIN_WORKERS=
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
API_TOOL_DEFAULT_READ_TIMEOUT=60

# -------------------------------
# Datasource Configuration
# -------------------------------
ENABLE_WEBSITE_JINAREADER=true
ENABLE_WEBSITE_FIRECRAWL=true
ENABLE_WEBSITE_WATERCRAWL=true

# ------------------------------
# Database Configuration

@ -404,6 +410,7 @@ MILVUS_TOKEN=
MILVUS_USER=
MILVUS_PASSWORD=
MILVUS_ENABLE_HYBRID_SEARCH=False
MILVUS_ANALYZER_PARAMS=

# MyScale configuration, only available when VECTOR_STORE is `myscale`
# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to:
|
|||||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||||
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||||
|
|
||||||
|
# Workflow storage configuration
|
||||||
|
# Options: rdbms, hybrid
|
||||||
|
# rdbms: Use only the relational database (default)
|
||||||
|
# hybrid: Save new data to object storage, read from both object storage and RDBMS
|
||||||
|
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
|
||||||
|
|
||||||
# HTTP request node in workflow configuration
|
# HTTP request node in workflow configuration
|
||||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||||
|
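
The hybrid option described above writes new workflow node execution records to object storage while still serving reads from both object storage and the relational database; with `rdbms` only the relational path is used. A minimal sketch of that read/write pattern follows. The class and method names are illustrative assumptions, not Dify's actual repository implementation.

```python
from typing import Optional, Protocol


class ExecutionStore(Protocol):
    """Minimal storage interface assumed for this sketch."""

    def save(self, execution_id: str, payload: dict) -> None: ...
    def load(self, execution_id: str) -> Optional[dict]: ...


class HybridNodeExecutionStorage:
    """Illustrative hybrid strategy: write new records to object storage,
    read from object storage first and fall back to the RDBMS."""

    def __init__(self, object_store: ExecutionStore, rdbms_store: ExecutionStore) -> None:
        self._object_store = object_store
        self._rdbms_store = rdbms_store

    def save(self, execution_id: str, payload: dict) -> None:
        # New data goes to object storage only.
        self._object_store.save(execution_id, payload)

    def load(self, execution_id: str) -> Optional[dict]:
        # Prefer object storage; older records may only exist in the RDBMS.
        record = self._object_store.load(execution_id)
        if record is not None:
            return record
        return self._rdbms_store.load(execution_id)


class InMemoryStore:
    """Stand-in backend so the sketch runs without external services."""

    def __init__(self) -> None:
        self._data: dict[str, dict] = {}

    def save(self, execution_id: str, payload: dict) -> None:
        self._data[execution_id] = payload

    def load(self, execution_id: str) -> Optional[dict]:
        return self._data.get(execution_id)


if __name__ == "__main__":
    storage = HybridNodeExecutionStorage(InMemoryStore(), InMemoryStore())
    storage.save("node-1", {"status": "succeeded"})
    print(storage.load("node-1"))  # {'status': 'succeeded'}
```
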
@ -17,8 +17,10 @@ services:
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
- db
- redis
db:
  condition: service_healthy
redis:
  condition: service_started
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage

@ -42,8 +44,10 @@ services:
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
- db
- redis
db:
  condition: service_healthy
redis:
  condition: service_started
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage

@ -71,7 +75,9 @@ services:
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-5}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
# The postgres database.
db:
image: postgres:15-alpine

@ -124,6 +130,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes:
- ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf

@ -60,6 +60,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes:
- ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf

@ -43,6 +43,9 @@ x-shared-env: &shared-api-worker-env
CELERY_MIN_WORKERS: ${CELERY_MIN_WORKERS:-}
API_TOOL_DEFAULT_CONNECT_TIMEOUT: ${API_TOOL_DEFAULT_CONNECT_TIMEOUT:-10}
API_TOOL_DEFAULT_READ_TIMEOUT: ${API_TOOL_DEFAULT_READ_TIMEOUT:-60}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
DB_USERNAME: ${DB_USERNAME:-postgres}
DB_PASSWORD: ${DB_PASSWORD:-difyai123456}
DB_HOST: ${DB_HOST:-db}

@ -139,6 +142,7 @@ x-shared-env: &shared-api-worker-env
MILVUS_USER: ${MILVUS_USER:-}
MILVUS_PASSWORD: ${MILVUS_PASSWORD:-}
MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False}
MILVUS_ANALYZER_PARAMS: ${MILVUS_ANALYZER_PARAMS:-}
MYSCALE_HOST: ${MYSCALE_HOST:-myscale}
MYSCALE_PORT: ${MYSCALE_PORT:-8123}
MYSCALE_USER: ${MYSCALE_USER:-default}

@ -323,6 +327,7 @@ x-shared-env: &shared-api-worker-env
MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}

@ -485,8 +490,10 @@ services:
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
- db
- redis
db:
  condition: service_healthy
redis:
  condition: service_started
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage

@ -510,8 +517,10 @@ services:
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
- db
- redis
db:
  condition: service_healthy
redis:
  condition: service_started
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage

@ -539,7 +548,9 @@ services:
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-5}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
# The postgres database.
db:
image: postgres:15-alpine

@ -592,6 +603,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes:
- ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf
@ -49,3 +49,8 @@ NEXT_PUBLIC_MAX_PARALLEL_LIMIT=10

# The maximum number of iterations for agent setting
NEXT_PUBLIC_MAX_ITERATIONS_NUM=5

NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=true
NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true
NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true

@ -7,7 +7,7 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next

### Run by source code

Before starting the web frontend service, please make sure the following environment is ready.
- [Node.js](https://nodejs.org) >= v18.x
- [Node.js](https://nodejs.org) >= v22.11.x
- [pnpm](https://pnpm.io) v10.x

First, install the dependencies:

@ -1,11 +1,11 @@
'use client'

import Workflow from '@/app/components/workflow'
import WorkflowApp from '@/app/components/workflow-app'

const Page = () => {
  return (
    <div className='h-full w-full overflow-x-auto'>
      <Workflow />
      <WorkflowApp />
    </div>
  )
}
@ -557,7 +557,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi

<Heading
url='/datasets/{dataset_id}'
method='POST'
method='PATCH'
title='Update knowledge base'
name='#update_dataset'
/>

@ -585,8 +585,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Property name='embedding_model' type='string' key='embedding_model'>
Specified embedding model, corresponding to the model field(Optional)
</Property>
<Property name='retrieval_model' type='string' key='retrieval_model'>
<Property name='retrieval_model' type='object' key='retrieval_model'>
Specified retrieval model, corresponding to the model field(Optional)
Retrieval model (optional; if not provided, recall follows the default method)
- <code>search_method</code> (text) Search method: one of the following four keywords, required
  - <code>keyword_search</code> Keyword search
  - <code>semantic_search</code> Semantic search
  - <code>full_text_search</code> Full-text search
  - <code>hybrid_search</code> Hybrid search
- <code>reranking_enable</code> (bool) Whether to enable reranking; optional, pass it when the search method is semantic_search or hybrid_search
- <code>reranking_mode</code> (object) Rerank model configuration; required when reranking is enabled
  - <code>reranking_provider_name</code> (string) Rerank model provider
  - <code>reranking_model_name</code> (string) Rerank model name
- <code>weights</code> (float) Semantic search weight setting in hybrid search mode
- <code>top_k</code> (integer) Number of results to return (optional)
- <code>score_threshold_enabled</code> (bool) Whether to enable the score threshold
- <code>score_threshold</code> (float) Score threshold
</Property>
<Property name='partial_member_list' type='array' key='partial_member_list'>
Partial member list(Optional)

@ -596,16 +609,56 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
tag="PATCH"
label="/datasets/{dataset_id}"
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me", "embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}' `}
targetCode={`curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{
  "name": "Test Knowledge Base",
  "indexing_technique": "high_quality",
  "permission": "only_me",
  "embedding_model_provider": "zhipuai",
  "embedding_model": "embedding-3",
  "retrieval_model": {
    "search_method": "keyword_search",
    "reranking_enable": false,
    "reranking_mode": null,
    "reranking_model": {
      "reranking_provider_name": "",
      "reranking_model_name": ""
    },
    "weights": null,
    "top_k": 1,
    "score_threshold_enabled": false,
    "score_threshold": null
  },
  "partial_member_list": []
}'
`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \
curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me",\
"embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}'
--data-raw '{
  "name": "Test Knowledge Base",
  "indexing_technique": "high_quality",
  "permission": "only_me",
  "embedding_model_provider": "zhipuai",
  "embedding_model": "embedding-3",
  "retrieval_model": {
    "search_method": "keyword_search",
    "reranking_enable": false,
    "reranking_mode": null,
    "reranking_model": {
      "reranking_provider_name": "",
      "reranking_model_name": ""
    },
    "weights": null,
    "top_k": 1,
    "score_threshold_enabled": false,
    "score_threshold": null
  },
  "partial_member_list": []
}'
```
</CodeGroup>
<CodeGroup title="Response">
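
For readers who prefer Python over curl, the same update call can be sketched with the third-party requests package. The endpoint, method, headers, and payload mirror the example above; the base URL, API key, and dataset ID below are placeholders you would substitute for your own instance.

```python
import requests

API_BASE_URL = "https://api.dify.ai/v1"  # placeholder; use your instance's API base URL
API_KEY = "your-dataset-api-key"         # placeholder
DATASET_ID = "your-dataset-id"           # placeholder

# Payload mirrors the cURL example: keyword search with reranking disabled.
payload = {
    "name": "Test Knowledge Base",
    "indexing_technique": "high_quality",
    "permission": "only_me",
    "embedding_model_provider": "zhipuai",
    "embedding_model": "embedding-3",
    "retrieval_model": {
        "search_method": "keyword_search",
        "reranking_enable": False,
        "reranking_mode": None,
        "reranking_model": {
            "reranking_provider_name": "",
            "reranking_model_name": "",
        },
        "weights": None,
        "top_k": 1,
        "score_threshold_enabled": False,
        "score_threshold": None,
    },
    "partial_member_list": [],
}

# PATCH /datasets/{dataset_id} updates the knowledge base settings.
response = requests.patch(
    f"{API_BASE_URL}/datasets/{DATASET_ID}",
    headers={"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"},
    json=payload,
)
response.raise_for_status()
print(response.json())
```
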
@ -94,6 +94,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
  - <code>semantic_search</code> Semantic search
  - <code>full_text_search</code> Full-text search
- <code>reranking_enable</code> (bool) Whether to enable reranking
- <code>reranking_mode</code> (String) Hybrid search mode
  - <code>weighted_score</code> Weighted score setting
  - <code>reranking_model</code> Rerank model
- <code>reranking_model</code> (object) Rerank model configuration
  - <code>reranking_provider_name</code> (string) Rerank model provider
  - <code>reranking_model_name</code> (string) Rerank model name

@ -557,7 +560,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi

<Heading
url='/datasets/{dataset_id}'
method='POST'
method='PATCH'
title='Update knowledge base details'
name='#update_dataset'
/>

@ -589,8 +592,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Property name='embedding_model' type='string' key='embedding_model'>
Embedding model (optional)
</Property>
<Property name='retrieval_model' type='string' key='retrieval_model'>
<Property name='retrieval_model' type='object' key='retrieval_model'>
Retrieval model (optional)
Retrieval settings (optional; if omitted, recall follows the default method)
- <code>search_method</code> (text) Search method: one of the following four keywords, required
  - <code>keyword_search</code> Keyword search
  - <code>semantic_search</code> Semantic search
  - <code>full_text_search</code> Full-text search
  - <code>hybrid_search</code> Hybrid search
- <code>reranking_enable</code> (bool) Whether to enable reranking; optional, pass it when the search method is semantic_search or hybrid_search
- <code>reranking_mode</code> (object) Rerank model configuration; optional, pass it when reranking is enabled
  - <code>reranking_provider_name</code> (string) Rerank model provider
  - <code>reranking_model_name</code> (string) Rerank model name
- <code>weights</code> (float) Weight of semantic search in hybrid search mode
- <code>top_k</code> (integer) Number of results to return (optional)
- <code>score_threshold_enabled</code> (bool) Whether to enable the score threshold
- <code>score_threshold</code> (float) Score threshold
</Property>
<Property name='partial_member_list' type='array' key='partial_member_list'>
Partial member ID list (optional)

@ -600,16 +616,56 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
tag="PATCH"
label="/datasets/{dataset_id}"
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me", "embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}' `}
targetCode={`curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{
  "name": "Test Knowledge Base",
  "indexing_technique": "high_quality",
  "permission": "only_me",
  "embedding_model_provider": "zhipuai",
  "embedding_model": "embedding-3",
  "retrieval_model": {
    "search_method": "keyword_search",
    "reranking_enable": false,
    "reranking_mode": null,
    "reranking_model": {
      "reranking_provider_name": "",
      "reranking_model_name": ""
    },
    "weights": null,
    "top_k": 1,
    "score_threshold_enabled": false,
    "score_threshold": null
  },
  "partial_member_list": []
}'
`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \
curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me",\
"embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}'
--data-raw '{
  "name": "Test Knowledge Base",
  "indexing_technique": "high_quality",
  "permission": "only_me",
  "embedding_model_provider": "zhipuai",
  "embedding_model": "embedding-3",
  "retrieval_model": {
    "search_method": "keyword_search",
    "reranking_enable": false,
    "reranking_mode": null,
    "reranking_model": {
      "reranking_provider_name": "",
      "reranking_model_name": ""
    },
    "weights": null,
    "top_k": 1,
    "score_threshold_enabled": false,
    "score_threshold": null
  },
  "partial_member_list": []
}'
```
</CodeGroup>
<CodeGroup title="Response">

@ -1764,7 +1820,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Property>
<Property name='retrieval_model' type='object' key='retrieval_model'>
Retrieval settings (optional; if omitted, recall follows the default method)
- <code>search_method</code> (text) Search method: one of the following three keywords, required
- <code>search_method</code> (text) Search method: one of the following four keywords, required
  - <code>keyword_search</code> Keyword search
  - <code>semantic_search</code> Semantic search
  - <code>full_text_search</code> Full-text search
@ -0,0 +1,82 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ConfigSelect from './index'

jest.mock('react-sortablejs', () => ({
  ReactSortable: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))

jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

describe('ConfigSelect Component', () => {
  const defaultProps = {
    options: ['Option 1', 'Option 2'],
    onChange: jest.fn(),
  }

  afterEach(() => {
    jest.clearAllMocks()
  })

  it('renders all options', () => {
    render(<ConfigSelect {...defaultProps} />)

    defaultProps.options.forEach((option) => {
      expect(screen.getByDisplayValue(option)).toBeInTheDocument()
    })
  })

  it('renders add button', () => {
    render(<ConfigSelect {...defaultProps} />)

    expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument()
  })

  it('handles option deletion', () => {
    render(<ConfigSelect {...defaultProps} />)
    const optionContainer = screen.getByDisplayValue('Option 1').closest('div')
    const deleteButton = optionContainer?.querySelector('div[role="button"]')

    if (!deleteButton) return
    fireEvent.click(deleteButton)
    expect(defaultProps.onChange).toHaveBeenCalledWith(['Option 2'])
  })

  it('handles adding new option', () => {
    render(<ConfigSelect {...defaultProps} />)
    const addButton = screen.getByText('appDebug.variableConfig.addOption')

    fireEvent.click(addButton)

    expect(defaultProps.onChange).toHaveBeenCalledWith([...defaultProps.options, ''])
  })

  it('applies focus styles on input focus', () => {
    render(<ConfigSelect {...defaultProps} />)
    const firstInput = screen.getByDisplayValue('Option 1')

    fireEvent.focus(firstInput)

    expect(firstInput.closest('div')).toHaveClass('border-components-input-border-active')
  })

  it('applies delete hover styles', () => {
    render(<ConfigSelect {...defaultProps} />)
    const optionContainer = screen.getByDisplayValue('Option 1').closest('div')
    const deleteButton = optionContainer?.querySelector('div[role="button"]')

    if (!deleteButton) return
    fireEvent.mouseEnter(deleteButton)
    expect(optionContainer).toHaveClass('border-components-input-border-destructive')
  })

  it('renders empty state correctly', () => {
    render(<ConfigSelect options={[]} onChange={defaultProps.onChange} />)

    expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
    expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument()
  })
})
@ -51,7 +51,7 @@ const ConfigSelect: FC<IConfigSelectProps> = ({
<RiDraggable className='handle h-4 w-4 cursor-grab text-text-quaternary' />
<input
key={index}
type="input"
type='input'
value={o || ''}
onChange={(e) => {
const value = e.target.value

@ -67,6 +67,7 @@ const ConfigSelect: FC<IConfigSelectProps> = ({
onBlur={() => setFocusID(null)}
/>
<div
role='button'
className='absolute right-1.5 top-1/2 block translate-y-[-50%] cursor-pointer rounded-md p-1 text-text-tertiary hover:bg-state-destructive-hover hover:text-text-destructive'
onClick={() => {
onChange(options.filter((_, i) => index !== i))