This commit is contained in:
jyong 2025-05-16 17:22:17 +08:00
parent 9e72afee3c
commit 8bea88c8cc
11 changed files with 80 additions and 41 deletions

View File

@ -229,7 +229,7 @@ class HostedFetchPipelineTemplateConfig(BaseSettings):
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field( HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,", description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
default="remote", default="database",
) )
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field( HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(

View File

@ -38,7 +38,7 @@ class PipelineTemplateListApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"]) type = request.args.get("type", default="built-in", type=str)
language = request.args.get("language", default="en-US", type=str) language = request.args.get("language", default="en-US", type=str)
# get pipeline templates # get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language) pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
@ -107,7 +107,7 @@ class CustomizedPipelineTemplateApi(Resource):
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first() pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
if not pipeline: if not pipeline:
raise ValueError("Pipeline not found.") raise ValueError("Pipeline not found.")
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True) dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
return {"data": dsl}, 200 return {"data": dsl}, 200

View File

@ -90,11 +90,10 @@ class DraftRagPipelineApi(Resource):
if "application/json" in content_type: if "application/json" in content_type:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json") parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=False, location="json") parser.add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("pipeline_variables", type=dict, required=False, location="json") parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -111,7 +110,7 @@ class DraftRagPipelineApi(Resource):
"hash": data.get("hash"), "hash": data.get("hash"),
"environment_variables": data.get("environment_variables"), "environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"), "conversation_variables": data.get("conversation_variables"),
"pipeline_variables": data.get("pipeline_variables"), "rag_pipeline_variables": data.get("rag_pipeline_variables"),
} }
except json.JSONDecodeError: except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400 return {"message": "Invalid JSON data"}, 400
@ -130,21 +129,20 @@ class DraftRagPipelineApi(Resource):
conversation_variables = [ conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
] ]
pipeline_variables_list = args.get("pipeline_variables") or {} rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {}
pipeline_variables = { rag_pipeline_variables = {
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v] k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
for k, v in pipeline_variables_list.items() for k, v in rag_pipeline_variables_list.items()
} }
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow( workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline, pipeline=pipeline,
graph=args["graph"], graph=args["graph"],
features=args["features"],
unique_hash=args.get("hash"), unique_hash=args.get("hash"),
account=current_user, account=current_user,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
pipeline_variables=pipeline_variables, rag_pipeline_variables=rag_pipeline_variables,
) )
except WorkflowHashNotEqualError: except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync() raise DraftWorkflowNotSync()
@ -476,7 +474,7 @@ class RagPipelineConfigApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self, pipeline_id):
return { return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
} }
@ -792,5 +790,5 @@ api.add_resource(
) )
api.add_resource( api.add_resource(
DatasourceListApi, DatasourceListApi,
"/rag/pipelines/datasources", "/rag/pipelines/datasource-plugins",
) )

View File

@ -153,6 +153,7 @@ pipeline_import_fields = {
"id": fields.String, "id": fields.String,
"status": fields.String, "status": fields.String,
"pipeline_id": fields.String, "pipeline_id": fields.String,
"dataset_id": fields.String,
"current_dsl_version": fields.String, "current_dsl_version": fields.String,
"imported_dsl_version": fields.String, "imported_dsl_version": fields.String,
"error": fields.String, "error": fields.String,

View File

@ -1166,6 +1166,9 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def pipeline(self):
return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_customized_templates" __tablename__ = "pipeline_customized_templates"
@ -1195,7 +1198,6 @@ class Pipeline(Base): # type: ignore[name-defined]
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
mode = db.Column(db.String(255), nullable=False)
workflow_id = db.Column(StringUUID, nullable=True) workflow_id = db.Column(StringUUID, nullable=True)
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
@ -1203,3 +1205,6 @@ class Pipeline(Base): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True) updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()

View File

@ -352,7 +352,7 @@ class Workflow(Base):
) )
@property @property
def pipeline_variables(self) -> dict[str, Sequence[Variable]]: def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]:
# TODO: find some way to init `self._conversation_variables` when instance created. # TODO: find some way to init `self._conversation_variables` when instance created.
if self._rag_pipeline_variables is None: if self._rag_pipeline_variables is None:
self._rag_pipeline_variables = "{}" self._rag_pipeline_variables = "{}"
@ -363,8 +363,8 @@ class Workflow(Base):
results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()] results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
return results return results
@pipeline_variables.setter @rag_pipeline_variables.setter
def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None: def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
self._rag_pipeline_variables = json.dumps( self._rag_pipeline_variables = json.dumps(
{k: {item.name: item.model_dump() for item in v} for k, v in values.items()}, {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
ensure_ascii=False, ensure_ascii=False,

View File

@ -40,6 +40,7 @@ from models.dataset import (
Document, Document,
DocumentSegment, DocumentSegment,
ExternalKnowledgeBindings, ExternalKnowledgeBindings,
Pipeline,
) )
from models.model import UploadFile from models.model import UploadFile
from models.source import DataSourceOauthBinding from models.source import DataSourceOauthBinding
@ -248,6 +249,15 @@ class DatasetService:
raise DatasetNameDuplicateError( raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
) )
pipeline = Pipeline(
tenant_id=tenant_id,
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
created_by=current_user.id
)
db.session.add(pipeline)
db.session.flush()
dataset = Dataset( dataset = Dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -257,7 +267,8 @@ class DatasetService:
provider="vendor", provider="vendor",
runtime_mode="rag_pipeline", runtime_mode="rag_pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info, icon_info=rag_pipeline_dataset_create_entity.icon_info,
created_by=current_user.id created_by=current_user.id,
pipeline_id=pipeline.id
) )
db.session.add(dataset) db.session.add(dataset)
db.session.commit() db.session.commit()
@ -282,10 +293,13 @@ class DatasetService:
runtime_mode="rag_pipeline", runtime_mode="rag_pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info, icon_info=rag_pipeline_dataset_create_entity.icon_info,
) )
with Session(db.engine) as session:
if rag_pipeline_dataset_create_entity.yaml_content: rag_pipeline_dsl_service = RagPipelineDslService(session)
rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline( rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset account=current_user,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=dataset
) )
return { return {
"id": rag_pipeline_import_info.id, "id": rag_pipeline_import_info.id,

View File

@ -1,10 +1,9 @@
from typing import Optional from typing import Optional
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Pipeline, PipelineBuiltInTemplate from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
#from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@ -30,11 +29,32 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language :param language: language
:return: :return:
""" """
pipeline_templates = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter(
) PipelineBuiltInTemplate.language == language
).all()
recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
pipeline_model: Pipeline = pipeline_built_in_template.pipeline
recommended_pipeline_result = {
'id': pipeline_built_in_template.id,
'name': pipeline_built_in_template.name,
'pipeline_id': pipeline_model.id,
'description': pipeline_built_in_template.description,
'icon': pipeline_built_in_template.icon,
'copyright': pipeline_built_in_template.copyright,
'privacy_policy': pipeline_built_in_template.privacy_policy,
'position': pipeline_built_in_template.position,
}
dataset: Dataset = pipeline_model.dataset
if dataset:
recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure
recommended_pipelines_results.append(recommended_pipeline_result)
return {'pipeline_templates': recommended_pipelines_results}
return {"pipeline_templates": pipeline_templates}
@classmethod @classmethod
def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]: def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
@ -43,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param pipeline_id: Pipeline ID :param pipeline_id: Pipeline ID
:return: :return:
""" """
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
# is in public recommended list # is in public recommended list
pipeline_template = ( pipeline_template = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first() db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()

View File

@ -15,7 +15,7 @@ class PipelineTemplateRetrievalFactory:
return DatabasePipelineTemplateRetrieval return DatabasePipelineTemplateRetrieval
case PipelineTemplateType.DATABASE: case PipelineTemplateType.DATABASE:
return DatabasePipelineTemplateRetrieval return DatabasePipelineTemplateRetrieval
case PipelineTemplateType.BUILT_IN: case PipelineTemplateType.BUILTIN:
return BuiltInPipelineTemplateRetrieval return BuiltInPipelineTemplateRetrieval
case _: case _:
raise ValueError(f"invalid fetch recommended apps mode: {mode}") raise ValueError(f"invalid fetch recommended apps mode: {mode}")

View File

@ -42,6 +42,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import PipelineT
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
class RagPipelineService: class RagPipelineService:
@staticmethod @staticmethod
def get_pipeline_templates( def get_pipeline_templates(
@ -49,7 +50,7 @@ class RagPipelineService:
) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]: ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
if type == "built-in": if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result = retrieval_instance.get_pipeline_templates(language) result = retrieval_instance.get_pipeline_templates(language)
if not result.get("pipeline_templates") and language != "en-US": if not result.get("pipeline_templates") and language != "en-US":
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
@ -57,7 +58,7 @@ class RagPipelineService:
return result.get("pipeline_templates") return result.get("pipeline_templates")
else: else:
mode = "customized" mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result = retrieval_instance.get_pipeline_templates(language) result = retrieval_instance.get_pipeline_templates(language)
return result.get("pipeline_templates") return result.get("pipeline_templates")
@ -200,7 +201,7 @@ class RagPipelineService:
account: Account, account: Account,
environment_variables: Sequence[Variable], environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[Variable],
pipeline_variables: dict[str, Sequence[Variable]], rag_pipeline_variables: dict[str, Sequence[Variable]],
) -> Workflow: ) -> Workflow:
""" """
Sync draft workflow Sync draft workflow
@ -217,15 +218,18 @@ class RagPipelineService:
workflow = Workflow( workflow = Workflow(
tenant_id=pipeline.tenant_id, tenant_id=pipeline.tenant_id,
app_id=pipeline.id, app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value, type=WorkflowType.RAG_PIPELINE.value,
version="draft", version="draft",
graph=json.dumps(graph), graph=json.dumps(graph),
created_by=account.id, created_by=account.id,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
pipeline_variables=pipeline_variables, rag_pipeline_variables=rag_pipeline_variables,
) )
db.session.add(workflow) db.session.add(workflow)
db.session.flush()
pipeline.workflow_id = workflow.id
# update draft workflow if found # update draft workflow if found
else: else:
workflow.graph = json.dumps(graph) workflow.graph = json.dumps(graph)
@ -233,7 +237,7 @@ class RagPipelineService:
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables workflow.conversation_variables = conversation_variables
workflow.pipeline_variables = pipeline_variables workflow.rag_pipeline_variables = rag_pipeline_variables
# commit db session changes # commit db session changes
db.session.commit() db.session.commit()

View File

@ -516,17 +516,14 @@ class RagPipelineDslService:
dependencies: Optional[list[PluginDependency]] = None, dependencies: Optional[list[PluginDependency]] = None,
) -> Pipeline: ) -> Pipeline:
"""Create a new app or update an existing one.""" """Create a new app or update an existing one."""
pipeline_data = data.get("pipeline", {}) pipeline_data = data.get("rag_pipeline", {})
pipeline_mode = pipeline_data.get("mode")
if not pipeline_mode:
raise ValueError("loss pipeline mode")
# Set icon type # Set icon type
icon_type_value = icon_type or pipeline_data.get("icon_type") icon_type_value = pipeline_data.get("icon_type")
if icon_type_value in ["emoji", "link"]: if icon_type_value in ["emoji", "link"]:
icon_type = icon_type_value icon_type = icon_type_value
else: else:
icon_type = "emoji" icon_type = "emoji"
icon = icon or str(pipeline_data.get("icon", "")) icon = str(pipeline_data.get("icon", ""))
if pipeline: if pipeline:
# Update existing pipeline # Update existing pipeline
@ -544,7 +541,6 @@ class RagPipelineDslService:
pipeline = Pipeline() pipeline = Pipeline()
pipeline.id = str(uuid4()) pipeline.id = str(uuid4())
pipeline.tenant_id = account.current_tenant_id pipeline.tenant_id = account.current_tenant_id
pipeline.mode = pipeline_mode.value
pipeline.name = pipeline_data.get("name", "") pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "") pipeline.description = pipeline_data.get("description", "")
pipeline.icon_type = icon_type pipeline.icon_type = icon_type