diff --git a/pkg/modules/tracefunnel/impltracefunnel/handler_test.go b/pkg/modules/tracefunnel/impltracefunnel/handler_test.go index fa113521b8..513c6aaaea 100644 --- a/pkg/modules/tracefunnel/impltracefunnel/handler_test.go +++ b/pkg/modules/tracefunnel/impltracefunnel/handler_test.go @@ -395,7 +395,13 @@ func TestHandler_Save(t *testing.T) { } mockModule.On("Get", req.Context(), reqBody.FunnelID.String()).Return(existingFunnel, nil) - mockModule.On("Save", req.Context(), mock.AnythingOfType("*tracefunnels.Funnel"), "user-123", orgID).Return(nil) + mockModule.On("Save", req.Context(), mock.MatchedBy(func(f *traceFunnels.Funnel) bool { + return f.ID.String() == reqBody.FunnelID.String() && + f.Name == existingFunnel.Name && + f.Description == reqBody.Description && + f.UpdatedBy == "user-123" && + f.OrgID.String() == orgID + }), "user-123", orgID).Return(nil) mockModule.On("GetFunnelMetadata", req.Context(), reqBody.FunnelID.String()).Return(int64(0), int64(0), reqBody.Description, nil) handler.Save(rr, req) diff --git a/pkg/modules/tracefunnel/impltracefunnel/module.go b/pkg/modules/tracefunnel/impltracefunnel/module.go index 5ff32c6696..6fd9097042 100644 --- a/pkg/modules/tracefunnel/impltracefunnel/module.go +++ b/pkg/modules/tracefunnel/impltracefunnel/module.go @@ -70,20 +70,12 @@ func (module *module) List(ctx context.Context, orgID string) ([]*traceFunnels.F return nil, fmt.Errorf("invalid org ID: %v", err) } - funnels, err := module.store.List(ctx) + funnels, err := module.store.List(ctx, orgUUID) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to list funnels: %v", err) } - // Filter by orgID - var orgFunnels []*traceFunnels.Funnel - for _, f := range funnels { - if f.OrgID == orgUUID { - orgFunnels = append(orgFunnels, f) - } - } - - return orgFunnels, nil + return funnels, nil } // Delete deletes a funnel diff --git a/pkg/modules/tracefunnel/impltracefunnel/module_test.go b/pkg/modules/tracefunnel/impltracefunnel/module_test.go index fef1bc9b2b..2a6560cd43 100644 --- a/pkg/modules/tracefunnel/impltracefunnel/module_test.go +++ b/pkg/modules/tracefunnel/impltracefunnel/module_test.go @@ -26,8 +26,8 @@ func (m *MockStore) Get(ctx context.Context, uuid valuer.UUID) (*traceFunnels.Fu return args.Get(0).(*traceFunnels.Funnel), args.Error(1) } -func (m *MockStore) List(ctx context.Context) ([]*traceFunnels.Funnel, error) { - args := m.Called(ctx) +func (m *MockStore) List(ctx context.Context, orgID valuer.UUID) ([]*traceFunnels.Funnel, error) { + args := m.Called(ctx, orgID) return args.Get(0).([]*traceFunnels.Funnel), args.Error(1) } @@ -51,7 +51,14 @@ func TestModule_Create(t *testing.T) { userID := "user-123" orgID := valuer.GenerateUUID().String() - mockStore.On("Create", ctx, mock.AnythingOfType("*tracefunnels.Funnel")).Return(nil) + mockStore.On("Create", ctx, mock.MatchedBy(func(f *traceFunnels.Funnel) bool { + return f.Name == name && + f.CreatedBy == userID && + f.OrgID.String() == orgID && + f.CreatedByUser != nil && + f.CreatedByUser.ID == userID && + f.CreatedAt.UnixNano()/1000000 == timestamp + })).Return(nil) funnel, err := module.Create(ctx, timestamp, name, userID, orgID) assert.NoError(t, err) @@ -59,6 +66,8 @@ func TestModule_Create(t *testing.T) { assert.Equal(t, name, funnel.Name) assert.Equal(t, userID, funnel.CreatedBy) assert.Equal(t, orgID, funnel.OrgID.String()) + assert.NotNil(t, funnel.CreatedByUser) + assert.Equal(t, userID, funnel.CreatedByUser.ID) mockStore.AssertExpectations(t) } @@ -111,22 +120,23 @@ func TestModule_List(t *testing.T) { ctx := context.Background() orgID := valuer.GenerateUUID().String() + orgUUID := valuer.MustNewUUID(orgID) expectedFunnels := []*traceFunnels.Funnel{ { BaseMetadata: traceFunnels.BaseMetadata{ Name: "funnel-1", - OrgID: valuer.MustNewUUID(orgID), + OrgID: orgUUID, }, }, { BaseMetadata: traceFunnels.BaseMetadata{ Name: "funnel-2", - OrgID: valuer.MustNewUUID(orgID), + OrgID: orgUUID, }, }, } - mockStore.On("List", ctx).Return(expectedFunnels, nil) + mockStore.On("List", ctx, orgUUID).Return(expectedFunnels, nil) funnels, err := module.List(ctx, orgID) assert.NoError(t, err) diff --git a/pkg/modules/tracefunnel/impltracefunnel/store.go b/pkg/modules/tracefunnel/impltracefunnel/store.go index d6421790aa..c7ddd048d4 100644 --- a/pkg/modules/tracefunnel/impltracefunnel/store.go +++ b/pkg/modules/tracefunnel/impltracefunnel/store.go @@ -88,8 +88,8 @@ func (store *store) Update(ctx context.Context, funnel *traceFunnels.Funnel) err return nil } -// List retrieves all funnels -func (store *store) List(ctx context.Context) ([]*traceFunnels.Funnel, error) { +// List retrieves all funnels for a given organization +func (store *store) List(ctx context.Context, orgID valuer.UUID) ([]*traceFunnels.Funnel, error) { var funnels []*traceFunnels.Funnel err := store. sqlstore. @@ -97,6 +97,7 @@ func (store *store) List(ctx context.Context) ([]*traceFunnels.Funnel, error) { NewSelect(). Model(&funnels). Relation("CreatedByUser"). + Where("?TableAlias.org_id = ?", orgID). Scan(ctx) if err != nil { return nil, fmt.Errorf("failed to list funnels: %v", err) diff --git a/pkg/types/tracefunnel/store.go b/pkg/types/tracefunnel/store.go index 6b8f1b56b2..618477f324 100644 --- a/pkg/types/tracefunnel/store.go +++ b/pkg/types/tracefunnel/store.go @@ -9,7 +9,7 @@ import ( type TraceFunnelStore interface { Create(context.Context, *Funnel) error Get(context.Context, valuer.UUID) (*Funnel, error) - List(context.Context) ([]*Funnel, error) + List(context.Context, valuer.UUID) ([]*Funnel, error) Update(context.Context, *Funnel) error Delete(context.Context, valuer.UUID) error }