This commit is contained in:
jyong 2025-05-16 12:02:35 +08:00
parent 4ff971c8a3
commit 7b0d38f7d3
2 changed files with 28 additions and 2 deletions

View File

@ -1,7 +1,8 @@
import logging
from flask import request
from flask_restful import Resource, reqparse # type: ignore # type: ignore
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.wraps import (
@ -9,6 +10,7 @@ from controllers.console.wraps import (
enterprise_license_required,
setup_required,
)
from extensions.ext_database import db
from libs.login import login_required
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -91,6 +93,15 @@ class CustomizedPipelineTemplateApi(Resource):
RagPipelineService.delete_customized_pipeline_template(template_id)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, template_id: str):
with Session(db.engine) as session:
dsl = RagPipelineService.export_template_rag_pipeline_dsl(template_id)
return {"data": dsl}, 200
api.add_resource(
PipelineTemplateListApi,
@ -102,5 +113,5 @@ api.add_resource(
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/templates/<string:template_id>",
"/rag/pipeline/customized/templates/<string:template_id>",
)

View File

@ -41,6 +41,7 @@ from models.workflow import (
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineService:
@ -115,6 +116,20 @@ class RagPipelineService:
db.delete(customized_template)
db.commit()
@classmethod
def export_template_rag_pipeline_dsl(cls, template_id: str) -> str:
"""
Export template rag pipeline dsl
"""
template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
if not template:
raise ValueError("Customized pipeline template not found.")
pipeline = db.session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found.")
return RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
"""
Get draft workflow