jyong 2025-05-16 17:22:17 +08:00
parent 9e72afee3c
commit 8bea88c8cc
11 changed files with 80 additions and 41 deletions

View File

@@ -229,7 +229,7 @@ class HostedFetchPipelineTemplateConfig(BaseSettings):
     HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
         description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
-        default="remote",
+        default="database",
     )
 
     HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
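Note: this flips the hosted template source from the remote endpoint to the local database. A minimal sketch of pinning the old behavior per deployment, assuming standard pydantic-settings semantics where the Field default applies only when the matching environment variable is unset (the `configs` import path follows the `dify_config` usage seen elsewhere in this diff):

import os

# Must be set before the settings object is instantiated.
os.environ["HOSTED_FETCH_PIPELINE_TEMPLATES_MODE"] = "remote"

from configs import dify_config

assert dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE == "remote"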

View File

@@ -38,7 +38,7 @@ class PipelineTemplateListApi(Resource):
     @account_initialization_required
     @enterprise_license_required
     def get(self):
-        type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
+        type = request.args.get("type", default="built-in", type=str)
         language = request.args.get("language", default="en-US", type=str)
         # get pipeline templates
         pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
@@ -107,7 +107,7 @@ class CustomizedPipelineTemplateApi(Resource):
         pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
         if not pipeline:
             raise ValueError("Pipeline not found.")
         dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
         return {"data": dsl}, 200
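Note: the `choices` keyword removed here is a `flask_restful.reqparse` feature; Werkzeug's `request.args.get` accepts only `default` and `type`, so passing `choices` raises a TypeError when the handler runs. If the whitelist is still wanted it has to be checked by hand; a sketch (the commit itself adds no such check):

template_type = request.args.get("type", default="built-in", type=str)
if template_type not in ("built-in", "customized"):
    template_type = "built-in"  # or abort(400); policy is up to the endpoint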

View File

@@ -90,11 +90,10 @@ class DraftRagPipelineApi(Resource):
         if "application/json" in content_type:
             parser = reqparse.RequestParser()
             parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
-            parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
             parser.add_argument("hash", type=str, required=False, location="json")
             parser.add_argument("environment_variables", type=list, required=False, location="json")
             parser.add_argument("conversation_variables", type=list, required=False, location="json")
-            parser.add_argument("pipeline_variables", type=dict, required=False, location="json")
+            parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json")
             args = parser.parse_args()
         elif "text/plain" in content_type:
             try:
@@ -111,7 +110,7 @@ class DraftRagPipelineApi(Resource):
                     "hash": data.get("hash"),
                     "environment_variables": data.get("environment_variables"),
                     "conversation_variables": data.get("conversation_variables"),
-                    "pipeline_variables": data.get("pipeline_variables"),
+                    "rag_pipeline_variables": data.get("rag_pipeline_variables"),
                 }
             except json.JSONDecodeError:
                 return {"message": "Invalid JSON data"}, 400
@@ -130,21 +129,20 @@ class DraftRagPipelineApi(Resource):
             conversation_variables = [
                 variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
             ]
-            pipeline_variables_list = args.get("pipeline_variables") or {}
-            pipeline_variables = {
+            rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {}
+            rag_pipeline_variables = {
                 k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
-                for k, v in pipeline_variables_list.items()
+                for k, v in rag_pipeline_variables_list.items()
             }
             rag_pipeline_service = RagPipelineService()
             workflow = rag_pipeline_service.sync_draft_workflow(
                 pipeline=pipeline,
                 graph=args["graph"],
-                features=args["features"],
                 unique_hash=args.get("hash"),
                 account=current_user,
                 environment_variables=environment_variables,
                 conversation_variables=conversation_variables,
-                pipeline_variables=pipeline_variables,
+                rag_pipeline_variables=rag_pipeline_variables,
             )
         except WorkflowHashNotEqualError:
             raise DraftWorkflowNotSync()
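Note: with `features` dropped and the rename applied, a draft-save request body now looks roughly like this. The field names inside the variable mappings are illustrative only; their real shape is whatever `build_pipeline_variable_from_mapping` expects:

payload = {
    "graph": {"nodes": [], "edges": []},
    "hash": "hash-of-previously-synced-draft",  # optional optimistic-lock token
    "environment_variables": [],
    "conversation_variables": [],
    "rag_pipeline_variables": {
        # one entry per group key; each value is a list of variable mappings
        "shared": [{"name": "top_k", "type": "number", "value": 3}],
    },
}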
@@ -476,7 +474,7 @@ class RagPipelineConfigApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self):
+    def get(self, pipeline_id):
         return {
             "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
         }
@@ -792,5 +790,5 @@
 )
 api.add_resource(
     DatasourceListApi,
-    "/rag/pipelines/datasources",
+    "/rag/pipelines/datasource-plugins",
 )
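Note: the `get(self, pipeline_id)` fix suggests the resource is registered on a route containing `<pipeline_id>`; Flask passes URL variables as view arguments, so the old signature raised a TypeError on every request. The datasource route rename is breaking for callers. A hypothetical client-side update, assuming the console API prefix and with `base_url`/`token` as placeholders:

import requests

# Old path now 404s: GET {base_url}/console/api/rag/pipelines/datasources
resp = requests.get(
    f"{base_url}/console/api/rag/pipelines/datasource-plugins",
    headers={"Authorization": f"Bearer {token}"},
)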

View File

@@ -153,6 +153,7 @@ pipeline_import_fields = {
     "id": fields.String,
     "status": fields.String,
     "pipeline_id": fields.String,
+    "dataset_id": fields.String,
     "current_dsl_version": fields.String,
     "imported_dsl_version": fields.String,
     "error": fields.String,

View File

@@ -1166,6 +1166,9 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
+    @property
+    def pipeline(self):
+        return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
 
 class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
     __tablename__ = "pipeline_customized_templates"
@@ -1195,7 +1198,6 @@ class Pipeline(Base):  # type: ignore[name-defined]
     tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
-    mode = db.Column(db.String(255), nullable=False)
     workflow_id = db.Column(StringUUID, nullable=True)
     is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
     is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
@@ -1203,3 +1205,6 @@ class Pipeline(Base):  # type: ignore[name-defined]
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by = db.Column(StringUUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    @property
+    def dataset(self):
+        return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
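Note: both helpers are plain properties that run a fresh query on every access (not SQLAlchemy relationships), so each access in a loop is one more SELECT. Usage sketch:

template = db.session.query(PipelineBuiltInTemplate).first()
if template is not None:
    pipeline = template.pipeline            # SELECT ... FROM pipelines WHERE id = ...
    if pipeline is not None and pipeline.dataset is not None:
        chunk_structure = pipeline.dataset.chunk_structure  # another SELECT per access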

View File

@@ -352,7 +352,7 @@ class Workflow(Base):
     )
 
     @property
-    def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
+    def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]:
         # TODO: find some way to init `self._conversation_variables` when instance created.
         if self._rag_pipeline_variables is None:
             self._rag_pipeline_variables = "{}"
@@ -363,8 +363,8 @@ class Workflow(Base):
             results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
         return results
 
-    @pipeline_variables.setter
-    def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
+    @rag_pipeline_variables.setter
+    def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
         self._rag_pipeline_variables = json.dumps(
             {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
             ensure_ascii=False,
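Note: the rename aligns the property with its backing column `_rag_pipeline_variables`. The pair round-trips through JSON, keyed first by group and then by variable name. A sketch, assuming `some_variable` is a `Variable` built elsewhere:

workflow.rag_pipeline_variables = {"shared": [some_variable]}
# Stored as: {"shared": {"<variable name>": {<model_dump of the variable>}}}
restored = workflow.rag_pipeline_variables["shared"]  # list[Variable] again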

View File

@@ -40,6 +40,7 @@ from models.dataset import (
     Document,
     DocumentSegment,
     ExternalKnowledgeBindings,
+    Pipeline,
 )
 from models.model import UploadFile
 from models.source import DataSourceOauthBinding
@@ -248,6 +249,15 @@ class DatasetService:
             raise DatasetNameDuplicateError(
                 f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
             )
+        pipeline = Pipeline(
+            tenant_id=tenant_id,
+            name=rag_pipeline_dataset_create_entity.name,
+            description=rag_pipeline_dataset_create_entity.description,
+            created_by=current_user.id
+        )
+        db.session.add(pipeline)
+        db.session.flush()
+
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -257,7 +267,8 @@
             provider="vendor",
             runtime_mode="rag_pipeline",
             icon_info=rag_pipeline_dataset_create_entity.icon_info,
-            created_by=current_user.id
+            created_by=current_user.id,
+            pipeline_id=pipeline.id
         )
         db.session.add(dataset)
         db.session.commit()
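Note: the `flush()` is what makes `pipeline.id` usable when the Dataset row is built: it emits the INSERT (populating the primary key) while keeping everything in the same uncommitted transaction, so pipeline and dataset are still created atomically. Minimal sketch of the pattern:

pipeline = Pipeline(tenant_id=tenant_id, name="example", created_by=user_id)
db.session.add(pipeline)
db.session.flush()                   # INSERT sent; pipeline.id now populated
dataset = Dataset(tenant_id=tenant_id, pipeline_id=pipeline.id)
db.session.add(dataset)
db.session.commit()                  # both rows commit together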
@@ -282,10 +293,13 @@ class DatasetService:
                 runtime_mode="rag_pipeline",
                 icon_info=rag_pipeline_dataset_create_entity.icon_info,
             )
         if rag_pipeline_dataset_create_entity.yaml_content:
-            rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline(
-                current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset
+            with Session(db.engine) as session:
+                rag_pipeline_dsl_service = RagPipelineDslService(session)
+                rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
+                    account=current_user,
+                    import_mode=ImportMode.YAML_CONTENT.value,
+                    yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
+                    dataset=dataset
             )
             return {
                 "id": rag_pipeline_import_info.id,

View File

@@ -1,10 +1,9 @@
 from typing import Optional
 
 from extensions.ext_database import db
-from models.dataset import Pipeline, PipelineBuiltInTemplate
+from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate
 from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
 from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
-#from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
 
 class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
@@ -30,11 +29,32 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         :param language: language
         :return:
         """
-        pipeline_templates = (
-            db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
-        )
+        pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter(
+            PipelineBuiltInTemplate.language == language
+        ).all()
+
+        recommended_pipelines_results = []
+        for pipeline_built_in_template in pipeline_built_in_templates:
+            pipeline_model: Pipeline = pipeline_built_in_template.pipeline
+            recommended_pipeline_result = {
+                'id': pipeline_built_in_template.id,
+                'name': pipeline_built_in_template.name,
+                'pipeline_id': pipeline_model.id,
+                'description': pipeline_built_in_template.description,
+                'icon': pipeline_built_in_template.icon,
+                'copyright': pipeline_built_in_template.copyright,
+                'privacy_policy': pipeline_built_in_template.privacy_policy,
+                'position': pipeline_built_in_template.position,
+            }
+            dataset: Dataset = pipeline_model.dataset
+            if dataset:
+                recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure
+            recommended_pipelines_results.append(recommended_pipeline_result)
+
+        return {'pipeline_templates': recommended_pipelines_results}
-        return {"pipeline_templates": pipeline_templates}
 
 @classmethod
 def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
@@ -43,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         :param pipeline_id: Pipeline ID
         :return:
         """
+        from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
         # is in public recommended list
         pipeline_template = (
             db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
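Two review notes: the `RagPipelineDslService` import stays deferred into the method body (the commented top-level import is deleted), which is the usual way to break an import cycle between the retrieval and DSL-service modules; and `pipeline_built_in_template.pipeline` can return None, so `pipeline_model.id` can raise. A hedged paraphrase of the guard that the loop arguably needs (not part of the commit):

for pipeline_built_in_template in pipeline_built_in_templates:
    pipeline_model = pipeline_built_in_template.pipeline
    if pipeline_model is None:   # the property may find no Pipeline row
        continue
    # ... build the result dict as above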

View File

@@ -15,7 +15,7 @@ class PipelineTemplateRetrievalFactory:
                 return DatabasePipelineTemplateRetrieval
             case PipelineTemplateType.DATABASE:
                 return DatabasePipelineTemplateRetrieval
-            case PipelineTemplateType.BUILT_IN:
+            case PipelineTemplateType.BUILTIN:
                 return BuiltInPipelineTemplateRetrieval
             case _:
                 raise ValueError(f"invalid fetch recommended apps mode: {mode}")
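Note: referencing a nonexistent enum member in a value pattern raises AttributeError as soon as that `case` is tried, so `BUILT_IN` broke every lookup that got past the earlier cases. A hedged reconstruction of the enum for orientation (only DATABASE and BUILTIN appear in this diff; the other members and all string values are assumptions):

from enum import Enum

class PipelineTemplateType(str, Enum):
    REMOTE = "remote"          # assumed
    DATABASE = "database"
    BUILTIN = "builtin"
    CUSTOMIZED = "customized"  # assumed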

View File

@@ -42,6 +42,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import PipelineT
 from services.errors.app import WorkflowHashNotEqualError
 from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
 
 
 class RagPipelineService:
     @staticmethod
     def get_pipeline_templates(
@@ -49,7 +50,7 @@ class RagPipelineService:
     ) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
         if type == "built-in":
             mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
-            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
+            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
             result = retrieval_instance.get_pipeline_templates(language)
             if not result.get("pipeline_templates") and language != "en-US":
                 template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
@@ -57,7 +58,7 @@ class RagPipelineService:
             return result.get("pipeline_templates")
         else:
             mode = "customized"
-            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
+            retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
             result = retrieval_instance.get_pipeline_templates(language)
             return result.get("pipeline_templates")
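Note: the trailing `()` matters because the factory returns the retrieval class, not an instance. Minimal sketch, assuming `get_pipeline_templates` is defined as an instance method `(self, language)`:

cls = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
cls.get_pipeline_templates(language)    # unbound call: `language` lands in `self`,
                                        # TypeError: missing argument 'language'
cls().get_pipeline_templates(language)  # correct: instantiate, then call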
@@ -200,7 +201,7 @@ class RagPipelineService:
         account: Account,
         environment_variables: Sequence[Variable],
         conversation_variables: Sequence[Variable],
-        pipeline_variables: dict[str, Sequence[Variable]],
+        rag_pipeline_variables: dict[str, Sequence[Variable]],
     ) -> Workflow:
         """
         Sync draft workflow
@@ -217,15 +218,18 @@ class RagPipelineService:
             workflow = Workflow(
                 tenant_id=pipeline.tenant_id,
                 app_id=pipeline.id,
+                features="{}",
                 type=WorkflowType.RAG_PIPELINE.value,
                 version="draft",
                 graph=json.dumps(graph),
                 created_by=account.id,
                 environment_variables=environment_variables,
                 conversation_variables=conversation_variables,
-                pipeline_variables=pipeline_variables,
+                rag_pipeline_variables=rag_pipeline_variables,
             )
             db.session.add(workflow)
+            db.session.flush()
+            pipeline.workflow_id = workflow.id
         # update draft workflow if found
         else:
             workflow.graph = json.dumps(graph)
@@ -233,7 +237,7 @@ class RagPipelineService:
             workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
             workflow.environment_variables = environment_variables
             workflow.conversation_variables = conversation_variables
-            workflow.pipeline_variables = pipeline_variables
+            workflow.rag_pipeline_variables = rag_pipeline_variables
 
         # commit db session changes
         db.session.commit()

View File

@@ -516,17 +516,14 @@ class RagPipelineDslService:
         dependencies: Optional[list[PluginDependency]] = None,
     ) -> Pipeline:
         """Create a new app or update an existing one."""
-        pipeline_data = data.get("pipeline", {})
-        pipeline_mode = pipeline_data.get("mode")
-        if not pipeline_mode:
-            raise ValueError("loss pipeline mode")
+        pipeline_data = data.get("rag_pipeline", {})
 
         # Set icon type
-        icon_type_value = icon_type or pipeline_data.get("icon_type")
+        icon_type_value = pipeline_data.get("icon_type")
         if icon_type_value in ["emoji", "link"]:
             icon_type = icon_type_value
         else:
             icon_type = "emoji"
-        icon = icon or str(pipeline_data.get("icon", ""))
+        icon = str(pipeline_data.get("icon", ""))
 
         if pipeline:
             # Update existing pipeline
@@ -544,7 +541,6 @@ class RagPipelineDslService:
             pipeline = Pipeline()
             pipeline.id = str(uuid4())
             pipeline.tenant_id = account.current_tenant_id
-            pipeline.mode = pipeline_mode.value
             pipeline.name = pipeline_data.get("name", "")
             pipeline.description = pipeline_data.get("description", "")
             pipeline.icon_type = icon_type
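Note: together with the `mode` column dropped from the `Pipeline` model above, this settles the DSL shape: the importer now reads a top-level `rag_pipeline` mapping and no longer requires a `mode` field, and `icon_type` values other than "emoji" or "link" fall back to "emoji". An illustrative DSL fragment (any field beyond what this diff shows is assumed):

data = {
    "rag_pipeline": {
        "name": "My RAG pipeline",
        "description": "",
        "icon_type": "emoji",   # non-"emoji"/"link" values fall back to "emoji"
        "icon": "🤖",
    },
    # ... graph/workflow sections omitted
}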