jyong 2025-05-15 16:07:17 +08:00
parent 360f8a3375
commit e710a8402c
8 changed files with 165 additions and 22 deletions

View File

@@ -2,13 +2,13 @@ from collections.abc import Generator
 from typing import Any, Optional
 from core.datasource.__base.datasource_runtime import DatasourceRuntime
-from core.datasource.datasource_manager import DatasourceManager
 from core.datasource.entities.datasource_entities import (
     DatasourceEntity,
     DatasourceInvokeMessage,
     DatasourceParameter,
     DatasourceProviderType,
 )
+from core.plugin.impl.datasource import PluginDatasourceManager
 from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -44,7 +44,7 @@ class DatasourcePlugin:
         datasource_parameters: dict[str, Any],
         rag_pipeline_id: Optional[str] = None,
     ) -> Generator[DatasourceInvokeMessage, None, None]:
-        manager = DatasourceManager()
+        manager = PluginDatasourceManager()
         datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
@@ -64,7 +64,7 @@ class DatasourcePlugin:
         datasource_parameters: dict[str, Any],
         rag_pipeline_id: Optional[str] = None,
     ) -> Generator[DatasourceInvokeMessage, None, None]:
-        manager = DatasourceManager()
+        manager = PluginDatasourceManager()
         datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
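Both hunks swap the same single line: the runtime now instantiates `PluginDatasourceManager` from `core.plugin.impl` instead of the old `DatasourceManager`. A minimal sketch of the shared shape of these generator methods; the `invoke` call on the manager is hypothetical, since the real daemon method name falls outside these hunks:

```python
from collections.abc import Generator
from typing import Any, Optional

from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format


def _stream_from_daemon(
    user_id: str,
    datasource_parameters: dict[str, Any],
    rag_pipeline_id: Optional[str] = None,
) -> Generator:
    manager = PluginDatasourceManager()
    # normalize parameter values into the plugin daemon's wire format
    datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
    # hypothetical daemon call; only the manager swap is shown in the diff above
    yield from manager.invoke(user_id, datasource_parameters, rag_pipeline_id)
```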

View File

@@ -7,8 +7,8 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
 from core.datasource.entities.common_entities import I18nObject
 from core.datasource.entities.datasource_entities import DatasourceProviderType
-from core.datasource.errors import ToolProviderNotFoundError
-from core.plugin.manager.tool import PluginToolManager
+from core.datasource.errors import DatasourceProviderNotFoundError
+from core.plugin.impl.tool import PluginToolManager
 logger = logging.getLogger(__name__)
@@ -37,9 +37,9 @@ class DatasourceManager:
                 return datasource_plugin_providers[provider]
             manager = PluginToolManager()
-            provider_entity = manager.fetch_tool_provider(tenant_id, provider)
+            provider_entity = manager.fetch_datasource_provider(tenant_id, provider)
             if not provider_entity:
-                raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
+                raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
             controller = DatasourcePluginProviderController(
                 entity=provider_entity.declaration,
@@ -73,7 +73,7 @@ class DatasourceManager:
         if provider_type == DatasourceProviderType.RAG_PIPELINE:
             return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name)
         else:
-            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
+            raise DatasourceProviderNotFoundError(f"provider type {provider_type.value} not found")
     @classmethod
     def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
@@ -81,7 +81,7 @@
         list all the datasource providers
         """
         manager = PluginToolManager()
-        provider_entities = manager.fetch_tool_providers(tenant_id)
+        provider_entities = manager.fetch_datasources(tenant_id)
         return [
             DatasourcePluginProviderController(
                 entity=provider.declaration,
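Taken together, these renames route provider lookups through the plugin daemon's datasource endpoints and raise datasource-specific errors instead of tool ones. A hedged sketch of resolving a datasource through this manager; the method name `get_datasource`, the keyword order, and the provider id string are assumptions inferred from the hunk context above:

```python
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceProviderType

# Hypothetical ids; real values come from installed datasource plugins.
datasource = DatasourceManager.get_datasource(
    provider_type=DatasourceProviderType.RAG_PIPELINE,  # the only branch that resolves above
    provider_id="langgenius/notion_datasource/notion",
    datasource_name="notion_search",
    tenant_id="tenant-uuid",
)
```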

View File

@@ -321,9 +321,6 @@ class DatasourceEntity(BaseModel):
     output_schema: Optional[dict] = None
     has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
-    # pydantic configs
-    model_config = ConfigDict(protected_namespaces=())
     @field_validator("parameters", mode="before")
     @classmethod
     def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:

View File

@@ -192,6 +192,9 @@ class ToolProviderID(GenericProviderID):
         if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
             self.plugin_name = f"{self.provider_name}_tool"
+class DatasourceProviderID(GenericProviderID):
+    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
+        super().__init__(value, is_hardcoded)
 class PluginDependency(BaseModel):
     class Type(enum.StrEnum):
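`DatasourceProviderID` adds no behavior of its own; it gives datasource provider strings their own type while reusing the `GenericProviderID` parsing. A sketch of the fields the rest of the commit relies on; the example id string is hypothetical, and the `plugin_id`/`provider_name` attributes are assumed from their use in `fetch_datasource_provider` below:

```python
from core.plugin.entities.plugin import DatasourceProviderID

provider_id = DatasourceProviderID("langgenius/notion_datasource/notion")  # hypothetical id
print(provider_id.plugin_id)      # assumed: the "org/plugin" prefix
print(provider_id.provider_name)  # assumed: the trailing provider segment
```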

View File

@@ -3,7 +3,7 @@ from typing import Any, Optional
 from pydantic import BaseModel
-from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
+from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID, ToolProviderID
 from core.plugin.entities.plugin_daemon import (
     PluginBasicBooleanResponse,
     PluginDatasourceProviderEntity,
@@ -76,6 +76,36 @@ class PluginToolManager(BasePluginClient):
         return response
+    def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
+        """
+        Fetch datasource provider for the given tenant and plugin.
+        """
+        datasource_provider_id = DatasourceProviderID(provider)
+
+        def transformer(json_response: dict[str, Any]) -> dict:
+            data = json_response.get("data")
+            if data:
+                for tool in data.get("declaration", {}).get("tools", []):
+                    tool["identity"]["provider"] = datasource_provider_id.provider_name
+            return json_response
+
+        response = self._request_with_plugin_daemon_response(
+            "GET",
+            f"plugin/{tenant_id}/management/datasource",
+            PluginDatasourceProviderEntity,
+            params={"provider": datasource_provider_id.provider_name, "plugin_id": datasource_provider_id.plugin_id},
+            transformer=transformer,
+        )
+
+        response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
+
+        # override the provider name for each tool to plugin_id/provider_name
+        for tool in response.declaration.tools:
+            tool.identity.provider = response.declaration.identity.name
+
+        return response
+
     def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
         """
         Fetch tool provider for the given tenant and plugin.
View File

@@ -0,0 +1,113 @@
+"""add_pipeline_info
+
+Revision ID: b35c3db83d09
+Revises: d28f2004b072
+Create Date: 2025-05-15 15:58:05.179877
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'b35c3db83d09'
+down_revision = 'd28f2004b072'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('pipeline_built_in_templates',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
+    sa.Column('name', sa.String(length=255), nullable=False),
+    sa.Column('description', sa.Text(), nullable=False),
+    sa.Column('icon', sa.JSON(), nullable=False),
+    sa.Column('copyright', sa.String(length=255), nullable=False),
+    sa.Column('privacy_policy', sa.String(length=255), nullable=False),
+    sa.Column('position', sa.Integer(), nullable=False),
+    sa.Column('install_count', sa.Integer(), nullable=False),
+    sa.Column('language', sa.String(length=255), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
+    )
+    op.create_table('pipeline_customized_templates',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
+    sa.Column('name', sa.String(length=255), nullable=False),
+    sa.Column('description', sa.Text(), nullable=False),
+    sa.Column('icon', sa.JSON(), nullable=False),
+    sa.Column('position', sa.Integer(), nullable=False),
+    sa.Column('install_count', sa.Integer(), nullable=False),
+    sa.Column('language', sa.String(length=255), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
+    )
+    with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
+        batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
+
+    op.create_table('pipelines',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('name', sa.String(length=255), nullable=False),
+    sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False),
+    sa.Column('mode', sa.String(length=255), nullable=False),
+    sa.Column('workflow_id', models.types.StringUUID(), nullable=True),
+    sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False),
+    sa.Column('is_published', sa.Boolean(), server_default=sa.text('false'), nullable=False),
+    sa.Column('created_by', models.types.StringUUID(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.Column('updated_by', models.types.StringUUID(), nullable=True),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
+    )
+    op.create_table('tool_builtin_datasource_providers',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
+    sa.Column('user_id', models.types.StringUUID(), nullable=False),
+    sa.Column('provider', sa.String(length=256), nullable=False),
+    sa.Column('encrypted_credentials', sa.Text(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='tool_builtin_datasource_provider_pkey'),
+    sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_datasource_provider')
+    )
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
+        batch_op.add_column(sa.Column('icon_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+        batch_op.add_column(sa.Column('runtime_mode', sa.String(length=255), server_default=sa.text("'general'::character varying"), nullable=True))
+        batch_op.add_column(sa.Column('pipeline_id', models.types.StringUUID(), nullable=True))
+        batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=True))
+
+    with op.batch_alter_table('workflows', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflows', schema=None) as batch_op:
+        batch_op.drop_column('rag_pipeline_variables')
+
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.drop_column('chunk_structure')
+        batch_op.drop_column('pipeline_id')
+        batch_op.drop_column('runtime_mode')
+        batch_op.drop_column('icon_info')
+        batch_op.drop_column('keyword_number')
+
+    op.drop_table('tool_builtin_datasource_providers')
+    op.drop_table('pipelines')
+    with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
+        batch_op.drop_index('pipeline_customized_template_tenant_idx')
+
+    op.drop_table('pipeline_customized_templates')
+    op.drop_table('pipeline_built_in_templates')
+    # ### end Alembic commands ###
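A hedged way to exercise this revision from Python, using Alembic's command API; the `alembic.ini` path is an assumption, since Dify normally drives migrations through Flask-Migrate (`flask db upgrade`):

```python
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config location

command.upgrade(cfg, "b35c3db83d09")      # apply add_pipeline_info
# command.downgrade(cfg, "d28f2004b072")  # roll back to the prior revision
```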

View File

@@ -1149,7 +1149,7 @@ class DatasetMetadataBinding(Base):
     created_by = db.Column(StringUUID, nullable=False)
-class PipelineBuiltInTemplate(db.Model): # type: ignore[name-defined]
+class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
     __tablename__ = "pipeline_built_in_templates"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
@@ -1167,7 +1167,7 @@ class PipelineBuiltInTemplate(db.Model): # type: ignore[name-defined]
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class PipelineCustomizedTemplate(db.Model): # type: ignore[name-defined]
+class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
     __tablename__ = "pipeline_customized_templates"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
@@ -1187,7 +1187,7 @@ class PipelineCustomizedTemplate(db.Model): # type: ignore[name-defined]
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-class Pipeline(db.Model): # type: ignore[name-defined]
+class Pipeline(Base): # type: ignore[name-defined]
     __tablename__ = "pipelines"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)

View File

@@ -128,8 +128,8 @@ class Workflow(Base):
     _conversation_variables: Mapped[str] = mapped_column(
         "conversation_variables", db.Text, nullable=False, server_default="{}"
     )
-    _pipeline_variables: Mapped[str] = mapped_column(
-        "conversation_variables", db.Text, nullable=False, server_default="{}"
+    _rag_pipeline_variables: Mapped[str] = mapped_column(
+        "rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
     )
     @classmethod
@@ -354,10 +354,10 @@ class Workflow(Base):
     @property
     def pipeline_variables(self) -> dict[str, Sequence[Variable]]:
         # TODO: find some way to init `self._conversation_variables` when instance created.
-        if self._pipeline_variables is None:
-            self._pipeline_variables = "{}"
+        if self._rag_pipeline_variables is None:
+            self._rag_pipeline_variables = "{}"
-        variables_dict: dict[str, Any] = json.loads(self._pipeline_variables)
+        variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
         results = {}
         for k, v in variables_dict.items():
             results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
@@ -365,7 +365,7 @@ class Workflow(Base):
     @pipeline_variables.setter
     def pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
-        self._pipeline_variables = json.dumps(
+        self._rag_pipeline_variables = json.dumps(
             {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
             ensure_ascii=False,
         )
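This fixes the old mapping bug where `_pipeline_variables` was bound to the `conversation_variables` column; the getter/setter pair now round-trips a `dict[str, Sequence[Variable]]` through the new `rag_pipeline_variables` text column, keyed first by group and then by variable name. A minimal sketch of that storage format with a stand-in variable type (the real `Variable` model and `variable_factory` are not shown in this diff):

```python
import json


class StubVariable:  # stand-in for the real Variable model
    def __init__(self, name: str, value: str) -> None:
        self.name = name
        self.value = value

    def model_dump(self) -> dict:
        return {"name": self.name, "value": self.value}


values = {"shared": [StubVariable("query", "hello")]}

# setter path: group -> {variable name -> serialized variable}, stored as text
raw = json.dumps(
    {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
    ensure_ascii=False,
)

# getter path: parse the column and rebuild each group's variable list
variables_dict = json.loads(raw)
restored = {k: [item for item in v.values()] for k, v in variables_dict.items()}
assert restored["shared"][0]["name"] == "query"
```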