mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 22:25:58 +08:00
r2
This commit is contained in:
parent
9e72afee3c
commit
8bea88c8cc
@ -229,7 +229,7 @@ class HostedFetchPipelineTemplateConfig(BaseSettings):
|
|||||||
|
|
||||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||||
default="remote",
|
default="database",
|
||||||
)
|
)
|
||||||
|
|
||||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||||
|
@ -38,7 +38,7 @@ class PipelineTemplateListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def get(self):
|
def get(self):
|
||||||
type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
|
type = request.args.get("type", default="built-in", type=str)
|
||||||
language = request.args.get("language", default="en-US", type=str)
|
language = request.args.get("language", default="en-US", type=str)
|
||||||
# get pipeline templates
|
# get pipeline templates
|
||||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||||
|
@ -90,11 +90,10 @@ class DraftRagPipelineApi(Resource):
|
|||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
|
||||||
parser.add_argument("hash", type=str, required=False, location="json")
|
parser.add_argument("hash", type=str, required=False, location="json")
|
||||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
||||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||||
parser.add_argument("pipeline_variables", type=dict, required=False, location="json")
|
parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
elif "text/plain" in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
@ -111,7 +110,7 @@ class DraftRagPipelineApi(Resource):
|
|||||||
"hash": data.get("hash"),
|
"hash": data.get("hash"),
|
||||||
"environment_variables": data.get("environment_variables"),
|
"environment_variables": data.get("environment_variables"),
|
||||||
"conversation_variables": data.get("conversation_variables"),
|
"conversation_variables": data.get("conversation_variables"),
|
||||||
"pipeline_variables": data.get("pipeline_variables"),
|
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
|
||||||
}
|
}
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return {"message": "Invalid JSON data"}, 400
|
return {"message": "Invalid JSON data"}, 400
|
||||||
@ -130,21 +129,20 @@ class DraftRagPipelineApi(Resource):
|
|||||||
conversation_variables = [
|
conversation_variables = [
|
||||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||||
]
|
]
|
||||||
pipeline_variables_list = args.get("pipeline_variables") or {}
|
rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {}
|
||||||
pipeline_variables = {
|
rag_pipeline_variables = {
|
||||||
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
|
k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
|
||||||
for k, v in pipeline_variables_list.items()
|
for k, v in rag_pipeline_variables_list.items()
|
||||||
}
|
}
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
graph=args["graph"],
|
graph=args["graph"],
|
||||||
features=args["features"],
|
|
||||||
unique_hash=args.get("hash"),
|
unique_hash=args.get("hash"),
|
||||||
account=current_user,
|
account=current_user,
|
||||||
environment_variables=environment_variables,
|
environment_variables=environment_variables,
|
||||||
conversation_variables=conversation_variables,
|
conversation_variables=conversation_variables,
|
||||||
pipeline_variables=pipeline_variables,
|
rag_pipeline_variables=rag_pipeline_variables,
|
||||||
)
|
)
|
||||||
except WorkflowHashNotEqualError:
|
except WorkflowHashNotEqualError:
|
||||||
raise DraftWorkflowNotSync()
|
raise DraftWorkflowNotSync()
|
||||||
@ -476,7 +474,7 @@ class RagPipelineConfigApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self, pipeline_id):
|
||||||
return {
|
return {
|
||||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||||
}
|
}
|
||||||
@ -792,5 +790,5 @@ api.add_resource(
|
|||||||
)
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
DatasourceListApi,
|
DatasourceListApi,
|
||||||
"/rag/pipelines/datasources",
|
"/rag/pipelines/datasource-plugins",
|
||||||
)
|
)
|
||||||
|
@ -153,6 +153,7 @@ pipeline_import_fields = {
|
|||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"status": fields.String,
|
"status": fields.String,
|
||||||
"pipeline_id": fields.String,
|
"pipeline_id": fields.String,
|
||||||
|
"dataset_id": fields.String,
|
||||||
"current_dsl_version": fields.String,
|
"current_dsl_version": fields.String,
|
||||||
"imported_dsl_version": fields.String,
|
"imported_dsl_version": fields.String,
|
||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
|
@ -1166,6 +1166,9 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
|||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pipeline(self):
|
||||||
|
return db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
|
||||||
|
|
||||||
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
||||||
__tablename__ = "pipeline_customized_templates"
|
__tablename__ = "pipeline_customized_templates"
|
||||||
@ -1195,7 +1198,6 @@ class Pipeline(Base): # type: ignore[name-defined]
|
|||||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||||
name = db.Column(db.String(255), nullable=False)
|
name = db.Column(db.String(255), nullable=False)
|
||||||
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
|
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
|
||||||
mode = db.Column(db.String(255), nullable=False)
|
|
||||||
workflow_id = db.Column(StringUUID, nullable=True)
|
workflow_id = db.Column(StringUUID, nullable=True)
|
||||||
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||||
is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||||
@ -1203,3 +1205,6 @@ class Pipeline(Base): # type: ignore[name-defined]
|
|||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
updated_by = db.Column(StringUUID, nullable=True)
|
updated_by = db.Column(StringUUID, nullable=True)
|
||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
@property
|
||||||
|
def dataset(self):
|
||||||
|
return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
|
||||||
|
@ -352,7 +352,7 @@ class Workflow(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
|
def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]:
|
||||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||||
if self._rag_pipeline_variables is None:
|
if self._rag_pipeline_variables is None:
|
||||||
self._rag_pipeline_variables = "{}"
|
self._rag_pipeline_variables = "{}"
|
||||||
@ -363,8 +363,8 @@ class Workflow(Base):
|
|||||||
results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
|
results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@pipeline_variables.setter
|
@rag_pipeline_variables.setter
|
||||||
def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
|
def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
|
||||||
self._rag_pipeline_variables = json.dumps(
|
self._rag_pipeline_variables = json.dumps(
|
||||||
{k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
|
{k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
|
@ -40,6 +40,7 @@ from models.dataset import (
|
|||||||
Document,
|
Document,
|
||||||
DocumentSegment,
|
DocumentSegment,
|
||||||
ExternalKnowledgeBindings,
|
ExternalKnowledgeBindings,
|
||||||
|
Pipeline,
|
||||||
)
|
)
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
from models.source import DataSourceOauthBinding
|
from models.source import DataSourceOauthBinding
|
||||||
@ -249,6 +250,15 @@ class DatasetService:
|
|||||||
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
name=rag_pipeline_dataset_create_entity.name,
|
||||||
|
description=rag_pipeline_dataset_create_entity.description,
|
||||||
|
created_by=current_user.id
|
||||||
|
)
|
||||||
|
db.session.add(pipeline)
|
||||||
|
db.session.flush()
|
||||||
|
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name=rag_pipeline_dataset_create_entity.name,
|
name=rag_pipeline_dataset_create_entity.name,
|
||||||
@ -257,7 +267,8 @@ class DatasetService:
|
|||||||
provider="vendor",
|
provider="vendor",
|
||||||
runtime_mode="rag_pipeline",
|
runtime_mode="rag_pipeline",
|
||||||
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
||||||
created_by=current_user.id
|
created_by=current_user.id,
|
||||||
|
pipeline_id=pipeline.id
|
||||||
)
|
)
|
||||||
db.session.add(dataset)
|
db.session.add(dataset)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -282,10 +293,13 @@ class DatasetService:
|
|||||||
runtime_mode="rag_pipeline",
|
runtime_mode="rag_pipeline",
|
||||||
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
icon_info=rag_pipeline_dataset_create_entity.icon_info,
|
||||||
)
|
)
|
||||||
|
with Session(db.engine) as session:
|
||||||
if rag_pipeline_dataset_create_entity.yaml_content:
|
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||||
rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline(
|
rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
|
||||||
current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset
|
account=current_user,
|
||||||
|
import_mode=ImportMode.YAML_CONTENT.value,
|
||||||
|
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
|
||||||
|
dataset=dataset
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"id": rag_pipeline_import_info.id,
|
"id": rag_pipeline_import_info.id,
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import Pipeline, PipelineBuiltInTemplate
|
from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate
|
||||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||||
#from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
|
||||||
|
|
||||||
|
|
||||||
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||||
@ -30,11 +29,32 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||||||
:param language: language
|
:param language: language
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
pipeline_templates = (
|
|
||||||
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"pipeline_templates": pipeline_templates}
|
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = db.session.query(PipelineBuiltInTemplate).filter(
|
||||||
|
PipelineBuiltInTemplate.language == language
|
||||||
|
).all()
|
||||||
|
|
||||||
|
recommended_pipelines_results = []
|
||||||
|
for pipeline_built_in_template in pipeline_built_in_templates:
|
||||||
|
pipeline_model: Pipeline = pipeline_built_in_template.pipeline
|
||||||
|
|
||||||
|
recommended_pipeline_result = {
|
||||||
|
'id': pipeline_built_in_template.id,
|
||||||
|
'name': pipeline_built_in_template.name,
|
||||||
|
'pipeline_id': pipeline_model.id,
|
||||||
|
'description': pipeline_built_in_template.description,
|
||||||
|
'icon': pipeline_built_in_template.icon,
|
||||||
|
'copyright': pipeline_built_in_template.copyright,
|
||||||
|
'privacy_policy': pipeline_built_in_template.privacy_policy,
|
||||||
|
'position': pipeline_built_in_template.position,
|
||||||
|
}
|
||||||
|
dataset: Dataset = pipeline_model.dataset
|
||||||
|
if dataset:
|
||||||
|
recommended_pipeline_result['chunk_structure'] = dataset.chunk_structure
|
||||||
|
recommended_pipelines_results.append(recommended_pipeline_result)
|
||||||
|
|
||||||
|
return {'pipeline_templates': recommended_pipelines_results}
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
|
def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
|
||||||
@ -43,6 +63,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||||||
:param pipeline_id: Pipeline ID
|
:param pipeline_id: Pipeline ID
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
# is in public recommended list
|
# is in public recommended list
|
||||||
pipeline_template = (
|
pipeline_template = (
|
||||||
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
|
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
|
||||||
|
@ -15,7 +15,7 @@ class PipelineTemplateRetrievalFactory:
|
|||||||
return DatabasePipelineTemplateRetrieval
|
return DatabasePipelineTemplateRetrieval
|
||||||
case PipelineTemplateType.DATABASE:
|
case PipelineTemplateType.DATABASE:
|
||||||
return DatabasePipelineTemplateRetrieval
|
return DatabasePipelineTemplateRetrieval
|
||||||
case PipelineTemplateType.BUILT_IN:
|
case PipelineTemplateType.BUILTIN:
|
||||||
return BuiltInPipelineTemplateRetrieval
|
return BuiltInPipelineTemplateRetrieval
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
|
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
|
||||||
|
@ -42,6 +42,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import PipelineT
|
|||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||||
|
|
||||||
|
|
||||||
class RagPipelineService:
|
class RagPipelineService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_pipeline_templates(
|
def get_pipeline_templates(
|
||||||
@ -49,7 +50,7 @@ class RagPipelineService:
|
|||||||
) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
|
) -> list[PipelineBuiltInTemplate | PipelineCustomizedTemplate]:
|
||||||
if type == "built-in":
|
if type == "built-in":
|
||||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
|
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||||
result = retrieval_instance.get_pipeline_templates(language)
|
result = retrieval_instance.get_pipeline_templates(language)
|
||||||
if not result.get("pipeline_templates") and language != "en-US":
|
if not result.get("pipeline_templates") and language != "en-US":
|
||||||
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
|
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
|
||||||
@ -57,7 +58,7 @@ class RagPipelineService:
|
|||||||
return result.get("pipeline_templates")
|
return result.get("pipeline_templates")
|
||||||
else:
|
else:
|
||||||
mode = "customized"
|
mode = "customized"
|
||||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)
|
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||||
result = retrieval_instance.get_pipeline_templates(language)
|
result = retrieval_instance.get_pipeline_templates(language)
|
||||||
return result.get("pipeline_templates")
|
return result.get("pipeline_templates")
|
||||||
|
|
||||||
@ -200,7 +201,7 @@ class RagPipelineService:
|
|||||||
account: Account,
|
account: Account,
|
||||||
environment_variables: Sequence[Variable],
|
environment_variables: Sequence[Variable],
|
||||||
conversation_variables: Sequence[Variable],
|
conversation_variables: Sequence[Variable],
|
||||||
pipeline_variables: dict[str, Sequence[Variable]],
|
rag_pipeline_variables: dict[str, Sequence[Variable]],
|
||||||
) -> Workflow:
|
) -> Workflow:
|
||||||
"""
|
"""
|
||||||
Sync draft workflow
|
Sync draft workflow
|
||||||
@ -217,15 +218,18 @@ class RagPipelineService:
|
|||||||
workflow = Workflow(
|
workflow = Workflow(
|
||||||
tenant_id=pipeline.tenant_id,
|
tenant_id=pipeline.tenant_id,
|
||||||
app_id=pipeline.id,
|
app_id=pipeline.id,
|
||||||
|
features="{}",
|
||||||
type=WorkflowType.RAG_PIPELINE.value,
|
type=WorkflowType.RAG_PIPELINE.value,
|
||||||
version="draft",
|
version="draft",
|
||||||
graph=json.dumps(graph),
|
graph=json.dumps(graph),
|
||||||
created_by=account.id,
|
created_by=account.id,
|
||||||
environment_variables=environment_variables,
|
environment_variables=environment_variables,
|
||||||
conversation_variables=conversation_variables,
|
conversation_variables=conversation_variables,
|
||||||
pipeline_variables=pipeline_variables,
|
rag_pipeline_variables=rag_pipeline_variables,
|
||||||
)
|
)
|
||||||
db.session.add(workflow)
|
db.session.add(workflow)
|
||||||
|
db.session.flush()
|
||||||
|
pipeline.workflow_id = workflow.id
|
||||||
# update draft workflow if found
|
# update draft workflow if found
|
||||||
else:
|
else:
|
||||||
workflow.graph = json.dumps(graph)
|
workflow.graph = json.dumps(graph)
|
||||||
@ -233,7 +237,7 @@ class RagPipelineService:
|
|||||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
workflow.environment_variables = environment_variables
|
workflow.environment_variables = environment_variables
|
||||||
workflow.conversation_variables = conversation_variables
|
workflow.conversation_variables = conversation_variables
|
||||||
workflow.pipeline_variables = pipeline_variables
|
workflow.rag_pipeline_variables = rag_pipeline_variables
|
||||||
# commit db session changes
|
# commit db session changes
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
@ -516,17 +516,14 @@ class RagPipelineDslService:
|
|||||||
dependencies: Optional[list[PluginDependency]] = None,
|
dependencies: Optional[list[PluginDependency]] = None,
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""Create a new app or update an existing one."""
|
"""Create a new app or update an existing one."""
|
||||||
pipeline_data = data.get("pipeline", {})
|
pipeline_data = data.get("rag_pipeline", {})
|
||||||
pipeline_mode = pipeline_data.get("mode")
|
|
||||||
if not pipeline_mode:
|
|
||||||
raise ValueError("loss pipeline mode")
|
|
||||||
# Set icon type
|
# Set icon type
|
||||||
icon_type_value = icon_type or pipeline_data.get("icon_type")
|
icon_type_value = pipeline_data.get("icon_type")
|
||||||
if icon_type_value in ["emoji", "link"]:
|
if icon_type_value in ["emoji", "link"]:
|
||||||
icon_type = icon_type_value
|
icon_type = icon_type_value
|
||||||
else:
|
else:
|
||||||
icon_type = "emoji"
|
icon_type = "emoji"
|
||||||
icon = icon or str(pipeline_data.get("icon", ""))
|
icon = str(pipeline_data.get("icon", ""))
|
||||||
|
|
||||||
if pipeline:
|
if pipeline:
|
||||||
# Update existing pipeline
|
# Update existing pipeline
|
||||||
@ -544,7 +541,6 @@ class RagPipelineDslService:
|
|||||||
pipeline = Pipeline()
|
pipeline = Pipeline()
|
||||||
pipeline.id = str(uuid4())
|
pipeline.id = str(uuid4())
|
||||||
pipeline.tenant_id = account.current_tenant_id
|
pipeline.tenant_id = account.current_tenant_id
|
||||||
pipeline.mode = pipeline_mode.value
|
|
||||||
pipeline.name = pipeline_data.get("name", "")
|
pipeline.name = pipeline_data.get("name", "")
|
||||||
pipeline.description = pipeline_data.get("description", "")
|
pipeline.description = pipeline_data.get("description", "")
|
||||||
pipeline.icon_type = icon_type
|
pipeline.icon_type = icon_type
|
||||||
|
Loading…
x
Reference in New Issue
Block a user