diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 4f1dfb6391..e674a89480 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -12,8 +12,10 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from libs.login import login_required +from models.dataset import Pipeline, PipelineCustomizedTemplate from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService +from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService logger = logging.getLogger(__name__) @@ -99,7 +101,14 @@ class CustomizedPipelineTemplateApi(Resource): @enterprise_license_required def post(self, template_id: str): with Session(db.engine) as session: - dsl = RagPipelineService.export_template_rag_pipeline_dsl(template_id) + template = session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + if not template: + raise ValueError("Customized pipeline template not found.") + pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found.") + + dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) return {"data": dsl}, 200 diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 166673130f..bc2cfdeeb3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -41,8 +41,6 @@ from models.workflow import ( from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory -from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService - class RagPipelineService: @staticmethod @@ -116,20 +114,6 @@ class RagPipelineService: db.delete(customized_template) db.commit() - @classmethod - def export_template_rag_pipeline_dsl(cls, template_id: str) -> str: - """ - Export template rag pipeline dsl - """ - template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() - if not template: - raise ValueError("Customized pipeline template not found.") - pipeline = db.session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() - if not pipeline: - raise ValueError("Pipeline not found.") - - return RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) - def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: """ Get draft workflow