diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 7633ffcf8a..3e57f24ff5 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -229,7 +229,7 @@ class HostedFetchPipelineTemplateConfig(BaseSettings): HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field( description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,", - default="remote", + default="database", ) HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field( diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index e674a89480..44296d5a31 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -38,7 +38,7 @@ class PipelineTemplateListApi(Resource): @account_initialization_required @enterprise_license_required def get(self): - type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"]) + type = request.args.get("type", default="built-in", type=str) language = request.args.get("language", default="en-US", type=str) # get pipeline templates pipeline_templates = RagPipelineService.get_pipeline_templates(type, language) @@ -107,7 +107,7 @@ class CustomizedPipelineTemplateApi(Resource): pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() if not pipeline: raise ValueError("Pipeline not found.") - + dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) return {"data": dsl}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index b348f7a796..c76014d0a3 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -90,11 +90,10 @@ class DraftRagPipelineApi(Resource): if "application/json" in content_type: parser = reqparse.RequestParser() parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") - parser.add_argument("features", type=dict, required=True, nullable=False, location="json") parser.add_argument("hash", type=str, required=False, location="json") parser.add_argument("environment_variables", type=list, required=False, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json") - parser.add_argument("pipeline_variables", type=dict, required=False, location="json") + parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json") args = parser.parse_args() elif "text/plain" in content_type: try: @@ -111,7 +110,7 @@ class DraftRagPipelineApi(Resource): "hash": data.get("hash"), "environment_variables": data.get("environment_variables"), "conversation_variables": data.get("conversation_variables"), - "pipeline_variables": data.get("pipeline_variables"), + "rag_pipeline_variables": data.get("rag_pipeline_variables"), } except json.JSONDecodeError: return {"message": "Invalid JSON data"}, 400 @@ -130,21 +129,20 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - pipeline_variables_list = args.get("pipeline_variables") or {} - pipeline_variables = { + rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {} + rag_pipeline_variables = { k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v] - for k, v in pipeline_variables_list.items() + for k, v in rag_pipeline_variables_list.items() } rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, graph=args["graph"], - features=args["features"], unique_hash=args.get("hash"), account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, - pipeline_variables=pipeline_variables, + rag_pipeline_variables=rag_pipeline_variables, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() @@ -476,7 +474,7 @@ class RagPipelineConfigApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): + def get(self, pipeline_id): return { "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, } @@ -792,5 +790,5 @@ api.add_resource( ) api.add_resource( DatasourceListApi, - "/rag/pipelines/datasources", + "/rag/pipelines/datasource-plugins", ) diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py index 0bb74e3259..cedc13ed0d 100644 --- a/api/fields/rag_pipeline_fields.py +++ b/api/fields/rag_pipeline_fields.py @@ -153,6 +153,7 @@ pipeline_import_fields = { "id": fields.String, "status": fields.String, "pipeline_id": fields.String, + "dataset_id": fields.String, "current_dsl_version": fields.String, "imported_dsl_version": fields.String, "error": fields.String, diff --git a/api/models/dataset.py b/api/models/dataset.py index e60f110aef..0ed59c898f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1166,6 +1166,9 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + @property + def pipeline(self): + return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] __tablename__ = "pipeline_customized_templates" @@ -1195,7 +1198,6 @@ class Pipeline(Base): # type: ignore[name-defined] tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) - mode = db.Column(db.String(255), nullable=False) workflow_id = db.Column(StringUUID, nullable=True) is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @@ -1203,3 +1205,6 @@ class Pipeline(Base): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() diff --git a/api/models/workflow.py b/api/models/workflow.py index b6b56ad520..5cb413b6a6 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -352,7 +352,7 @@ class Workflow(Base): ) @property - def pipeline_variables(self) -> dict[str, Sequence[Variable]]: + def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]: # TODO: find some way to init `self._conversation_variables` when instance created. if self._rag_pipeline_variables is None: self._rag_pipeline_variables = "{}" @@ -363,8 +363,8 @@ class Workflow(Base): results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()] return results - @pipeline_variables.setter - def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None: + @rag_pipeline_variables.setter + def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None: self._rag_pipeline_variables = json.dumps( {k: {item.name: item.model_dump() for item in v} for k, v in values.items()}, ensure_ascii=False, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac45981ee5..0f5069f052 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -40,6 +40,7 @@ from models.dataset import ( Document, DocumentSegment, ExternalKnowledgeBindings, + Pipeline, ) from models.model import UploadFile from models.source import DataSourceOauthBinding @@ -248,6 +249,15 @@ class DatasetService: raise DatasetNameDuplicateError( f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." ) + + pipeline = Pipeline( + tenant_id=tenant_id, + name=rag_pipeline_dataset_create_entity.name, + description=rag_pipeline_dataset_create_entity.description, + created_by=current_user.id + ) + db.session.add(pipeline) + db.session.flush() dataset = Dataset( tenant_id=tenant_id, @@ -257,7 +267,8 @@ class DatasetService: provider="vendor", runtime_mode="rag_pipeline", icon_info=rag_pipeline_dataset_create_entity.icon_info, - created_by=current_user.id + created_by=current_user.id, + pipeline_id=pipeline.id ) db.session.add(dataset) db.session.commit() @@ -282,10 +293,13 @@ class DatasetService: runtime_mode="rag_pipeline", icon_info=rag_pipeline_dataset_create_entity.icon_info, ) - - if rag_pipeline_dataset_create_entity.yaml_content: - rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline( - current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset + with Session(db.engine) as session: + rag_pipeline_dsl_service = RagPipelineDslService(session) + rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline( + account=current_user, + import_mode=ImportMode.YAML_CONTENT.value, + yaml_content=rag_pipeline_dataset_create_entity.yaml_content, + dataset=dataset ) return { "id": rag_pipeline_import_info.id, diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index f6ab5c9064..bda29c804c 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,10 +1,9 @@ from typing import Optional from extensions.ext_database import db -from models.dataset import Pipeline, PipelineBuiltInTemplate +from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType -#from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): @@ -30,11 +29,32 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - pipeline_templates = ( - db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() - ) + + pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter( + PipelineBuiltInTemplate.language == language + ).all() + + recommended_pipelines_results = [] + for pipeline_built_in_template in pipeline_built_in_templates: + pipeline_model: Pipeline = pipeline_built_in_template.pipeline + + recommended_pipeline_result = { + 'id': pipeline_built_in_template.id, + 'name': pipeline_built_in_template.name, + 'pipeline_id': pipeline_model.id, + 'description': pipeline_built_in_template.description, + 'icon': pipeline_built_in_template.icon, + 'copyright': pipeline_built_in_template.copyright, + 'privacy_policy': pipeline_built_in_template.privacy_policy, + 'position': pipeline_built_in_template.position, + } + dataset: Dataset = pipeline_model.dataset + if dataset: + recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure + recommended_pipelines_results.append(recommended_pipeline_result) + + return {'pipeline_templates': recommended_pipelines_results} - return {"pipeline_templates": pipeline_templates} @classmethod def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]: @@ -43,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param pipeline_id: Pipeline ID :return: """ + from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService # is in public recommended list pipeline_template = ( db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py index 37e40bf6a0..aa8a6298d7 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_factory.py @@ -15,7 +15,7 @@ class PipelineTemplateRetrievalFactory: return DatabasePipelineTemplateRetrieval case PipelineTemplateType.DATABASE: return DatabasePipelineTemplateRetrieval - case PipelineTemplateType.BUILT_IN: + case PipelineTemplateType.BUILTIN: return BuiltInPipelineTemplateRetrieval case _: raise ValueError(f"invalid fetch recommended apps mode: {mode}") diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index bc2cfdeeb3..f380bc32d7 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -42,6 +42,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import PipelineT from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory + class RagPipelineService: @staticmethod def get_pipeline_templates( @@ -49,7 +50,7 @@ class RagPipelineService: ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE - retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() result = retrieval_instance.get_pipeline_templates(language) if not result.get("pipeline_templates") and language != "en-US": template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() @@ -57,7 +58,7 @@ class RagPipelineService: return result.get("pipeline_templates") else: mode = "customized" - retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() result = retrieval_instance.get_pipeline_templates(language) return result.get("pipeline_templates") @@ -200,7 +201,7 @@ class RagPipelineService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], - pipeline_variables: dict[str, Sequence[Variable]], + rag_pipeline_variables: dict[str, Sequence[Variable]], ) -> Workflow: """ Sync draft workflow @@ -217,15 +218,18 @@ class RagPipelineService: workflow = Workflow( tenant_id=pipeline.tenant_id, app_id=pipeline.id, + features="{}", type=WorkflowType.RAG_PIPELINE.value, version="draft", graph=json.dumps(graph), created_by=account.id, environment_variables=environment_variables, conversation_variables=conversation_variables, - pipeline_variables=pipeline_variables, + rag_pipeline_variables=rag_pipeline_variables, ) db.session.add(workflow) + db.session.flush() + pipeline.workflow_id = workflow.id # update draft workflow if found else: workflow.graph = json.dumps(graph) @@ -233,7 +237,7 @@ class RagPipelineService: workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables - workflow.pipeline_variables = pipeline_variables + workflow.rag_pipeline_variables = rag_pipeline_variables # commit db session changes db.session.commit() diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index e50caa9756..3664c988e5 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -516,17 +516,14 @@ class RagPipelineDslService: dependencies: Optional[list[PluginDependency]] = None, ) -> Pipeline: """Create a new app or update an existing one.""" - pipeline_data = data.get("pipeline", {}) - pipeline_mode = pipeline_data.get("mode") - if not pipeline_mode: - raise ValueError("loss pipeline mode") + pipeline_data = data.get("rag_pipeline", {}) # Set icon type - icon_type_value = icon_type or pipeline_data.get("icon_type") + icon_type_value = pipeline_data.get("icon_type") if icon_type_value in ["emoji", "link"]: icon_type = icon_type_value else: icon_type = "emoji" - icon = icon or str(pipeline_data.get("icon", "")) + icon = str(pipeline_data.get("icon", "")) if pipeline: # Update existing pipeline @@ -544,7 +541,6 @@ class RagPipelineDslService: pipeline = Pipeline() pipeline.id = str(uuid4()) pipeline.tenant_id = account.current_tenant_id - pipeline.mode = pipeline_mode.value pipeline.name = pipeline_data.get("name", "") pipeline.description = pipeline_data.get("description", "") pipeline.icon_type = icon_type