Merge branch 'feat/r2' into deploy/dev

This commit is contained in:
jyong 2025-05-16 13:46:02 +08:00
commit 3048d97e54
2 changed files with 10 additions and 17 deletions

View File

@ -12,8 +12,10 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import login_required
from models.dataset import Pipeline, PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -99,7 +101,14 @@ class CustomizedPipelineTemplateApi(Resource):
@enterprise_license_required @enterprise_license_required
def post(self, template_id: str): def post(self, template_id: str):
with Session(db.engine) as session: with Session(db.engine) as session:
dsl = RagPipelineService.export_template_rag_pipeline_dsl(template_id) template = session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
if not template:
raise ValueError("Customized pipeline template not found.")
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found.")
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
return {"data": dsl}, 200 return {"data": dsl}, 200

View File

@ -41,8 +41,6 @@ from models.workflow import (
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineService: class RagPipelineService:
@staticmethod @staticmethod
@ -116,20 +114,6 @@ class RagPipelineService:
db.delete(customized_template) db.delete(customized_template)
db.commit() db.commit()
@classmethod
def export_template_rag_pipeline_dsl(cls, template_id: str) -> str:
"""
Export template rag pipeline dsl
"""
template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
if not template:
raise ValueError("Customized pipeline template not found.")
pipeline = db.session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found.")
return RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
""" """
Get draft workflow Get draft workflow