mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 22:25:57 +08:00
r2
This commit is contained in:
parent
9e72afee3c
commit
8bea88c8cc
@ -229,7 +229,7 @@ class HostedFetchPipelineTemplateConfig(BaseSettings):
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||
default="remote",
|
||||
default="database",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
|
@ -38,7 +38,7 @@ class PipelineTemplateListApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
@ -107,7 +107,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found.")
|
||||
|
||||
|
||||
dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
|
||||
return {"data": dsl}, 200
|
||||
|
||||
|
@ -90,11 +90,10 @@ class DraftRagPipelineApi(Resource):
|
||||
if "application/json" in content_type:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("pipeline_variables", type=dict, required=False, location="json")
|
||||
parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
@ -111,7 +110,7 @@ class DraftRagPipelineApi(Resource):
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"pipeline_variables": data.get("pipeline_variables"),
|
||||
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
@ -130,21 +129,20 @@ class DraftRagPipelineApi(Resource):
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
pipeline_variables_list = args.get("pipeline_variables") or {}
|
||||
pipeline_variables = {
|
||||
rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {}
|
||||
rag_pipeline_variables = {
|
||||
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
|
||||
for k, v in pipeline_variables_list.items()
|
||||
for k, v in rag_pipeline_variables_list.items()
|
||||
}
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=args["graph"],
|
||||
features=args["features"],
|
||||
unique_hash=args.get("hash"),
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
pipeline_variables=pipeline_variables,
|
||||
rag_pipeline_variables=rag_pipeline_variables,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
@ -476,7 +474,7 @@ class RagPipelineConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
def get(self, pipeline_id):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
@ -792,5 +790,5 @@ api.add_resource(
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceListApi,
|
||||
"/rag/pipelines/datasources",
|
||||
"/rag/pipelines/datasource-plugins",
|
||||
)
|
||||
|
@ -153,6 +153,7 @@ pipeline_import_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"pipeline_id": fields.String,
|
||||
"dataset_id": fields.String,
|
||||
"current_dsl_version": fields.String,
|
||||
"imported_dsl_version": fields.String,
|
||||
"error": fields.String,
|
||||
|
@ -1166,6 +1166,9 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def pipeline(self):
|
||||
return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
|
||||
|
||||
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
||||
__tablename__ = "pipeline_customized_templates"
|
||||
@ -1195,7 +1198,6 @@ class Pipeline(Base): # type: ignore[name-defined]
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
|
||||
mode = db.Column(db.String(255), nullable=False)
|
||||
workflow_id = db.Column(StringUUID, nullable=True)
|
||||
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
@ -1203,3 +1205,6 @@ class Pipeline(Base): # type: ignore[name-defined]
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
|
||||
|
@ -352,7 +352,7 @@ class Workflow(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
|
||||
def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
@ -363,8 +363,8 @@ class Workflow(Base):
|
||||
results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
|
||||
return results
|
||||
|
||||
@pipeline_variables.setter
|
||||
def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
|
||||
ensure_ascii=False,
|
||||
|
@ -40,6 +40,7 @@ from models.dataset import (
|
||||
Document,
|
||||
DocumentSegment,
|
||||
ExternalKnowledgeBindings,
|
||||
Pipeline,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.source import DataSourceOauthBinding
|
||||
@ -248,6 +249,15 @@ class DatasetService:
|
||||
raise DatasetNameDuplicateError(
|
||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
tenant_id=tenant_id,
|
||||
name=rag_pipeline_dataset_create_entity.name,
|
||||
description=rag_pipeline_dataset_create_entity.description,
|
||||
created_by=current_user.id
|
||||
)
|
||||
db.session.add(pipeline)
|
||||
db.session.flush()
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@ -257,7 +267,8 @@ class DatasetService:
|
||||
provider="vendor",
|
||||
runtime_mode="rag_pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
||||
created_by=current_user.id
|
||||
created_by=current_user.id,
|
||||
pipeline_id=pipeline.id
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
@ -282,10 +293,13 @@ class DatasetService:
|
||||
runtime_mode="rag_pipeline",
|
||||
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
||||
)
|
||||
|
||||
if rag_pipeline_dataset_create_entity.yaml_content:
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline(
|
||||
current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||
account=current_user,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||
dataset=dataset
|
||||
)
|
||||
return {
|
||||
"id": rag_pipeline_import_info.id,
|
||||
|
@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Pipeline, PipelineBuiltInTemplate
|
||||
from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
#from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
@ -30,11 +29,32 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
pipeline_templates = (
|
||||
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
|
||||
)
|
||||
|
||||
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter(
|
||||
PipelineBuiltInTemplate.language == language
|
||||
).all()
|
||||
|
||||
recommended_pipelines_results = []
|
||||
for pipeline_built_in_template in pipeline_built_in_templates:
|
||||
pipeline_model: Pipeline = pipeline_built_in_template.pipeline
|
||||
|
||||
recommended_pipeline_result = {
|
||||
'id': pipeline_built_in_template.id,
|
||||
'name': pipeline_built_in_template.name,
|
||||
'pipeline_id': pipeline_model.id,
|
||||
'description': pipeline_built_in_template.description,
|
||||
'icon': pipeline_built_in_template.icon,
|
||||
'copyright': pipeline_built_in_template.copyright,
|
||||
'privacy_policy': pipeline_built_in_template.privacy_policy,
|
||||
'position': pipeline_built_in_template.position,
|
||||
}
|
||||
dataset: Dataset = pipeline_model.dataset
|
||||
if dataset:
|
||||
recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure
|
||||
recommended_pipelines_results.append(recommended_pipeline_result)
|
||||
|
||||
return {'pipeline_templates': recommended_pipelines_results}
|
||||
|
||||
return {"pipeline_templates": pipeline_templates}
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
|
||||
@ -43,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:param pipeline_id: Pipeline ID
|
||||
:return:
|
||||
"""
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
# is in public recommended list
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
|
||||
|
@ -15,7 +15,7 @@ class PipelineTemplateRetrievalFactory:
|
||||
return DatabasePipelineTemplateRetrieval
|
||||
case PipelineTemplateType.DATABASE:
|
||||
return DatabasePipelineTemplateRetrieval
|
||||
case PipelineTemplateType.BUILT_IN:
|
||||
case PipelineTemplateType.BUILTIN:
|
||||
return BuiltInPipelineTemplateRetrieval
|
||||
case _:
|
||||
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
|
||||
|
@ -42,6 +42,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import PipelineT
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||
|
||||
|
||||
class RagPipelineService:
|
||||
@staticmethod
|
||||
def get_pipeline_templates(
|
||||
@ -49,7 +50,7 @@ class RagPipelineService:
|
||||
) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
|
||||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result = retrieval_instance.get_pipeline_templates(language)
|
||||
if not result.get("pipeline_templates") and language != "en-US":
|
||||
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
|
||||
@ -57,7 +58,7 @@ class RagPipelineService:
|
||||
return result.get("pipeline_templates")
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result = retrieval_instance.get_pipeline_templates(language)
|
||||
return result.get("pipeline_templates")
|
||||
|
||||
@ -200,7 +201,7 @@ class RagPipelineService:
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
pipeline_variables: dict[str, Sequence[Variable]],
|
||||
rag_pipeline_variables: dict[str, Sequence[Variable]],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
@ -217,15 +218,18 @@ class RagPipelineService:
|
||||
workflow = Workflow(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
features="{}",
|
||||
type=WorkflowType.RAG_PIPELINE.value,
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
created_by=account.id,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
pipeline_variables=pipeline_variables,
|
||||
rag_pipeline_variables=rag_pipeline_variables,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.flush()
|
||||
pipeline.workflow_id = workflow.id
|
||||
# update draft workflow if found
|
||||
else:
|
||||
workflow.graph = json.dumps(graph)
|
||||
@ -233,7 +237,7 @@ class RagPipelineService:
|
||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
workflow.pipeline_variables = pipeline_variables
|
||||
workflow.rag_pipeline_variables = rag_pipeline_variables
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
||||
|
@ -516,17 +516,14 @@ class RagPipelineDslService:
|
||||
dependencies: Optional[list[PluginDependency]] = None,
|
||||
) -> Pipeline:
|
||||
"""Create a new app or update an existing one."""
|
||||
pipeline_data = data.get("pipeline", {})
|
||||
pipeline_mode = pipeline_data.get("mode")
|
||||
if not pipeline_mode:
|
||||
raise ValueError("loss pipeline mode")
|
||||
pipeline_data = data.get("rag_pipeline", {})
|
||||
# Set icon type
|
||||
icon_type_value = icon_type or pipeline_data.get("icon_type")
|
||||
icon_type_value = pipeline_data.get("icon_type")
|
||||
if icon_type_value in ["emoji", "link"]:
|
||||
icon_type = icon_type_value
|
||||
else:
|
||||
icon_type = "emoji"
|
||||
icon = icon or str(pipeline_data.get("icon", ""))
|
||||
icon = str(pipeline_data.get("icon", ""))
|
||||
|
||||
if pipeline:
|
||||
# Update existing pipeline
|
||||
@ -544,7 +541,6 @@ class RagPipelineDslService:
|
||||
pipeline = Pipeline()
|
||||
pipeline.id = str(uuid4())
|
||||
pipeline.tenant_id = account.current_tenant_id
|
||||
pipeline.mode = pipeline_mode.value
|
||||
pipeline.name = pipeline_data.get("name", "")
|
||||
pipeline.description = pipeline_data.get("description", "")
|
||||
pipeline.icon_type = icon_type
|
||||
|
Loading…
x
Reference in New Issue
Block a user