diff --git a/api/controllers/console/datasets/pipeline.py b/api/controllers/console/datasets/pipeline.py new file mode 100644 index 0000000000..20a3df8a1b --- /dev/null +++ b/api/controllers/console/datasets/pipeline.py @@ -0,0 +1,42 @@ +from flask import request +from flask_login import current_user # type: ignore # type: ignore +from flask_restful import Resource, marshal # type: ignore + +from controllers.console import api +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, +) +from core.model_runtime.entities.model_entities import ModelType +from core.plugin.entities.plugin import ModelProviderID +from core.provider_manager import ProviderManager +from fields.dataset_fields import dataset_detail_fields +from libs.login import login_required +from services.dataset_service import DatasetPermissionService, DatasetService + + +def _validate_name(name): + if not name or len(name) < 1 or len(name) > 40: + raise ValueError("Name must be between 1 to 40 characters.") + return name + + +def _validate_description_length(description): + if len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + +class PipelineTemplateListApi(Resource): + @setup_required + @login_required + @account_initialization_required + @enterprise_license_required + def get(self): + type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"]) + # get pipeline templates + return response, 200 + + +api.add_resource(PipelineTemplateListApi, "/rag/pipeline/templates") diff --git a/api/models/dataset.py b/api/models/dataset.py index d6708ac88b..1e274a31f8 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1138,3 +1138,42 @@ class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] document_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_by = db.Column(StringUUID, nullable=False) + +class PipelineBuiltInTemplate(db.Model): # type: ignore[name-defined] + __tablename__ = "pipeline_built_in_templates" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + app_id = db.Column(StringUUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=False) + icon = db.Column(db.JSON, nullable=False) + copyright = db.Column(db.String(255), nullable=False) + privacy_policy = db.Column(db.String(255), nullable=False) + position = db.Column(db.Integer, nullable=False) + install_count = db.Column(db.Integer, nullable=False, default=0) + language = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class PipelineCustomizedTemplate(db.Model): # type: ignore[name-defined] + __tablename__ = "pipeline_customized_templates" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"), + db.Index("pipeline_customized_template_tenant_idx", "tenant_id"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.Text, nullable=False) + icon = db.Column(db.JSON, nullable=False) + position = db.Column(db.Integer, nullable=False) + install_count = db.Column(db.Integer, nullable=False, default=0) + language = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py new file mode 100644 index 0000000000..c6d1769679 --- /dev/null +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -0,0 +1,20 @@ +import datetime +import hashlib +import os +import uuid +from typing import Any, List, Literal, Union + +from flask_login import current_user + +from models.dataset import PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore + + +class RagPipelineService: + @staticmethod + def get_pipeline_templates( + type: Literal["built-in", "customized"] = "built-in", + ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: + if type == "built-in": + return PipelineBuiltInTemplate.query.all() + else: + return PipelineCustomizedTemplate.query.all()