diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 864f3644d5..4f1dfb6391 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restful import Resource, reqparse # type: ignore # type: ignore +from flask_restful import Resource, reqparse +from sqlalchemy.orm import Session from controllers.console import api from controllers.console.wraps import ( @@ -9,6 +10,7 @@ from controllers.console.wraps import ( enterprise_license_required, setup_required, ) +from extensions.ext_database import db from libs.login import login_required from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -91,6 +93,15 @@ class CustomizedPipelineTemplateApi(Resource): RagPipelineService.delete_customized_pipeline_template(template_id) return 200 + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def post(self, template_id: str): + with Session(db.engine) as session: + dsl = RagPipelineService.export_template_rag_pipeline_dsl(template_id) + return {"data": dsl}, 200 + api.add_resource( PipelineTemplateListApi, @@ -102,5 +113,5 @@ api.add_resource( ) api.add_resource( CustomizedPipelineTemplateApi, - "/rag/pipeline/templates/", + "/rag/pipeline/customized/templates/", ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 2275c32f63..166673130f 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -41,6 +41,7 @@ from models.workflow import ( from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService class RagPipelineService: @@ -115,6 +116,20 @@ class RagPipelineService: db.delete(customized_template) db.commit() + @classmethod + def export_template_rag_pipeline_dsl(cls, template_id: str) -> str: + """ + Export template rag pipeline dsl + """ + template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + if not template: + raise ValueError("Customized pipeline template not found.") + pipeline = db.session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found.") + + return RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) + def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: """ Get draft workflow