diff --git a/ee/query-service/app/api/pat.go b/ee/query-service/app/api/pat.go index f89f85c67f..95a4ae0788 100644 --- a/ee/query-service/app/api/pat.go +++ b/ee/query-service/app/api/pat.go @@ -10,9 +10,12 @@ import ( "github.com/SigNoz/signoz/ee/query-service/model" "github.com/SigNoz/signoz/ee/types" eeTypes "github.com/SigNoz/signoz/ee/types" + "github.com/SigNoz/signoz/pkg/errors" + "github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/query-service/auth" baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants" basemodel "github.com/SigNoz/signoz/pkg/query-service/model" + "github.com/SigNoz/signoz/pkg/valuer" "github.com/gorilla/mux" "go.uber.org/zap" ) @@ -93,7 +96,12 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) { } req.UpdatedByUserID = user.ID - id := mux.Vars(r)["id"] + idStr := mux.Vars(r)["id"] + id, err := valuer.NewUUID(idStr) + if err != nil { + render.Error(w, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "id is not a valid uuid-v7")) + return + } req.UpdatedAt = time.Now() zap.L().Info("Got Update PAT request", zap.Any("pat", req)) var apierr basemodel.BaseApiError @@ -126,7 +134,12 @@ func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) { ctx := context.Background() - id := mux.Vars(r)["id"] + idStr := mux.Vars(r)["id"] + id, err := valuer.NewUUID(idStr) + if err != nil { + render.Error(w, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "id is not a valid uuid-v7")) + return + } user, err := auth.GetUserFromReqContext(r.Context()) if err != nil { RespondError(w, &model.ApiError{ @@ -136,7 +149,7 @@ func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) { return } - zap.L().Info("Revoke PAT with id", zap.String("id", id)) + zap.L().Info("Revoke PAT with id", zap.String("id", id.StringValue())) if apierr := ah.AppDao().RevokePAT(ctx, user.OrgID, id, user.ID); apierr != nil { RespondError(w, apierr, nil) return diff --git a/ee/query-service/dao/interface.go b/ee/query-service/dao/interface.go index 0ee500a327..6c5bc1e612 100644 --- a/ee/query-service/dao/interface.go +++ b/ee/query-service/dao/interface.go @@ -10,6 +10,7 @@ import ( basemodel "github.com/SigNoz/signoz/pkg/query-service/model" ossTypes "github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types/authtypes" + "github.com/SigNoz/signoz/pkg/valuer" "github.com/google/uuid" "github.com/uptrace/bun" ) @@ -36,10 +37,10 @@ type ModelDao interface { GetDomainByEmail(ctx context.Context, email string) (*types.GettableOrgDomain, basemodel.BaseApiError) CreatePAT(ctx context.Context, orgID string, p types.GettablePAT) (types.GettablePAT, basemodel.BaseApiError) - UpdatePAT(ctx context.Context, orgID string, p types.GettablePAT, id string) basemodel.BaseApiError + UpdatePAT(ctx context.Context, orgID string, p types.GettablePAT, id valuer.UUID) basemodel.BaseApiError GetPAT(ctx context.Context, pat string) (*types.GettablePAT, basemodel.BaseApiError) - GetPATByID(ctx context.Context, orgID string, id string) (*types.GettablePAT, basemodel.BaseApiError) + GetPATByID(ctx context.Context, orgID string, id valuer.UUID) (*types.GettablePAT, basemodel.BaseApiError) GetUserByPAT(ctx context.Context, orgID string, token string) (*ossTypes.GettableUser, basemodel.BaseApiError) ListPATs(ctx context.Context, orgID string) ([]types.GettablePAT, basemodel.BaseApiError) - RevokePAT(ctx context.Context, orgID string, id string, userID string) basemodel.BaseApiError + RevokePAT(ctx context.Context, orgID string, id valuer.UUID, userID string) basemodel.BaseApiError } diff --git a/ee/query-service/dao/sqlite/pat.go b/ee/query-service/dao/sqlite/pat.go index e1a08f8d40..be51b716f5 100644 --- a/ee/query-service/dao/sqlite/pat.go +++ b/ee/query-service/dao/sqlite/pat.go @@ -9,12 +9,14 @@ import ( "github.com/SigNoz/signoz/ee/types" basemodel "github.com/SigNoz/signoz/pkg/query-service/model" ossTypes "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/valuer" "go.uber.org/zap" ) func (m *modelDao) CreatePAT(ctx context.Context, orgID string, p types.GettablePAT) (types.GettablePAT, basemodel.BaseApiError) { p.StorablePersonalAccessToken.OrgID = orgID + p.StorablePersonalAccessToken.ID = valuer.GenerateUUID() _, err := m.DB().NewInsert(). Model(&p.StorablePersonalAccessToken). Exec(ctx) @@ -46,11 +48,11 @@ func (m *modelDao) CreatePAT(ctx context.Context, orgID string, p types.Gettable return p, nil } -func (m *modelDao) UpdatePAT(ctx context.Context, orgID string, p types.GettablePAT, id string) basemodel.BaseApiError { +func (m *modelDao) UpdatePAT(ctx context.Context, orgID string, p types.GettablePAT, id valuer.UUID) basemodel.BaseApiError { _, err := m.DB().NewUpdate(). Model(&p.StorablePersonalAccessToken). Column("role", "name", "updated_at", "updated_by_user_id"). - Where("id = ?", id). + Where("id = ?", id.StringValue()). Where("org_id = ?", orgID). Where("revoked = false"). Exec(ctx) @@ -127,14 +129,14 @@ func (m *modelDao) ListPATs(ctx context.Context, orgID string) ([]types.Gettable return patsWithUsers, nil } -func (m *modelDao) RevokePAT(ctx context.Context, orgID string, id string, userID string) basemodel.BaseApiError { +func (m *modelDao) RevokePAT(ctx context.Context, orgID string, id valuer.UUID, userID string) basemodel.BaseApiError { updatedAt := time.Now().Unix() _, err := m.DB().NewUpdate(). Model(&types.StorablePersonalAccessToken{}). Set("revoked = ?", true). Set("updated_by_user_id = ?", userID). Set("updated_at = ?", updatedAt). - Where("id = ?", id). + Where("id = ?", id.StringValue()). Where("org_id = ?", orgID). Exec(ctx) if err != nil { @@ -169,12 +171,12 @@ func (m *modelDao) GetPAT(ctx context.Context, token string) (*types.GettablePAT return &patWithUser, nil } -func (m *modelDao) GetPATByID(ctx context.Context, orgID string, id string) (*types.GettablePAT, basemodel.BaseApiError) { +func (m *modelDao) GetPATByID(ctx context.Context, orgID string, id valuer.UUID) (*types.GettablePAT, basemodel.BaseApiError) { pats := []types.StorablePersonalAccessToken{} if err := m.DB().NewSelect(). Model(&pats). - Where("id = ?", id). + Where("id = ?", id.StringValue()). Where("org_id = ?", orgID). Where("revoked = false"). Scan(ctx); err != nil { diff --git a/ee/sqlstore/postgressqlstore/dialect.go b/ee/sqlstore/postgressqlstore/dialect.go index fead4e0fa7..c15fbe8d56 100644 --- a/ee/sqlstore/postgressqlstore/dialect.go +++ b/ee/sqlstore/postgressqlstore/dialect.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "slices" "github.com/SigNoz/signoz/pkg/errors" "github.com/uptrace/bun" @@ -192,7 +193,10 @@ func (dialect *dialect) TableExists(ctx context.Context, bun bun.IDB, table inte return true, nil } -func (dialect *dialect) RenameTableAndModifyModel(ctx context.Context, bun bun.IDB, oldModel interface{}, newModel interface{}, cb func(context.Context) error) error { +func (dialect *dialect) RenameTableAndModifyModel(ctx context.Context, bun bun.IDB, oldModel interface{}, newModel interface{}, references []string, cb func(context.Context) error) error { + if len(references) == 0 { + return errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "cannot run migration without reference") + } exists, err := dialect.TableExists(ctx, bun, newModel) if err != nil { return err @@ -201,12 +205,25 @@ func (dialect *dialect) RenameTableAndModifyModel(ctx context.Context, bun bun.I return nil } - _, err = bun. + var fkReferences []string + for _, reference := range references { + if reference == Org && !slices.Contains(fkReferences, OrgReference) { + fkReferences = append(fkReferences, OrgReference) + } else if reference == User && !slices.Contains(fkReferences, UserReference) { + fkReferences = append(fkReferences, UserReference) + } + } + + createTable := bun. NewCreateTable(). IfNotExists(). - Model(newModel). - Exec(ctx) + Model(newModel) + for _, fk := range fkReferences { + createTable = createTable.ForeignKey(fk) + } + + _, err = createTable.Exec(ctx) if err != nil { return err } diff --git a/ee/types/personal_access_token.go b/ee/types/personal_access_token.go index d791cba2a8..ff666740f9 100644 --- a/ee/types/personal_access_token.go +++ b/ee/types/personal_access_token.go @@ -6,6 +6,7 @@ import ( "time" "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/valuer" "github.com/uptrace/bun" ) @@ -28,11 +29,10 @@ func NewGettablePAT(name, role, userID string, expiresAt int64) GettablePAT { } type StorablePersonalAccessToken struct { - bun.BaseModel `bun:"table:personal_access_tokens"` - + bun.BaseModel `bun:"table:personal_access_token"` + types.Identifiable types.TimeAuditable OrgID string `json:"orgId" bun:"org_id,type:text,notnull"` - ID int `json:"id" bun:"id,pk,autoincrement"` Role string `json:"role" bun:"role,type:text,notnull,default:'ADMIN'"` UserID string `json:"userId" bun:"user_id,type:text,notnull"` Token string `json:"token" bun:"token,type:text,notnull,unique"` @@ -69,5 +69,8 @@ func NewStorablePersonalAccessToken(name, role, userID string, expiresAt int64) CreatedAt: now, UpdatedAt: now, }, + Identifiable: types.Identifiable{ + ID: valuer.GenerateUUID(), + }, } } diff --git a/pkg/query-service/auth/auth.go b/pkg/query-service/auth/auth.go index 93101e9a5f..2b4e349d5a 100644 --- a/pkg/query-service/auth/auth.go +++ b/pkg/query-service/auth/auth.go @@ -329,6 +329,9 @@ func CreateResetPasswordToken(ctx context.Context, userId string) (*types.ResetP } req := &types.ResetPasswordRequest{ + Identifiable: types.Identifiable{ + ID: valuer.GenerateUUID(), + }, UserID: userId, Token: token, } diff --git a/pkg/query-service/telemetry/telemetry.go b/pkg/query-service/telemetry/telemetry.go index de36970c68..ede62dcfa6 100644 --- a/pkg/query-service/telemetry/telemetry.go +++ b/pkg/query-service/telemetry/telemetry.go @@ -317,6 +317,7 @@ func createTelemetry() { getLogsInfoInLastHeartBeatInterval, _ := telemetry.reader.GetLogsInfoInLastHeartBeatInterval(ctx, HEART_BEAT_DURATION) + // TODO update this post bootstrap decision traceTTL, _ := telemetry.reader.GetTTL(ctx, "", &model.GetTTLParams{Type: constants.TraceTTL}) metricsTTL, _ := telemetry.reader.GetTTL(ctx, "", &model.GetTTLParams{Type: constants.MetricsTTL}) logsTTL, _ := telemetry.reader.GetTTL(ctx, "", &model.GetTTLParams{Type: constants.LogsTTL}) diff --git a/pkg/signoz/provider.go b/pkg/signoz/provider.go index 7af36f7830..eec3a8bcc5 100644 --- a/pkg/signoz/provider.go +++ b/pkg/signoz/provider.go @@ -68,6 +68,7 @@ func NewSQLMigrationProviderFactories(sqlstore sqlstore.SQLStore) factory.NamedM sqlmigration.NewUpdateAlertmanagerFactory(sqlstore), sqlmigration.NewUpdatePreferencesFactory(sqlstore), sqlmigration.NewUpdateApdexTtlFactory(sqlstore), + sqlmigration.NewUpdateResetPasswordFactory(sqlstore), ) } diff --git a/pkg/sqlmigration/019_update_invites.go b/pkg/sqlmigration/019_update_invites.go index 574216a67b..cb7a78ddbc 100644 --- a/pkg/sqlmigration/019_update_invites.go +++ b/pkg/sqlmigration/019_update_invites.go @@ -75,7 +75,7 @@ func (migration *updateInvites) Up(ctx context.Context, db *bun.DB) error { err = migration. store. Dialect(). - RenameTableAndModifyModel(ctx, tx, new(existingInvite), new(newInvite), func(ctx context.Context) error { + RenameTableAndModifyModel(ctx, tx, new(existingInvite), new(newInvite), []string{OrgReference}, func(ctx context.Context) error { existingInvites := make([]*existingInvite, 0) err = tx. NewSelect(). diff --git a/pkg/sqlmigration/021_update_alertmanager.go b/pkg/sqlmigration/021_update_alertmanager.go index 51561534e0..49ce90ce0c 100644 --- a/pkg/sqlmigration/021_update_alertmanager.go +++ b/pkg/sqlmigration/021_update_alertmanager.go @@ -110,7 +110,7 @@ func (migration *updateAlertmanager) Up(ctx context.Context, db *bun.DB) error { err = migration. store. Dialect(). - RenameTableAndModifyModel(ctx, tx, new(existingChannel), new(newChannel), func(ctx context.Context) error { + RenameTableAndModifyModel(ctx, tx, new(existingChannel), new(newChannel), []string{OrgReference}, func(ctx context.Context) error { existingChannels := make([]*existingChannel, 0) err = tx. NewSelect(). diff --git a/pkg/sqlmigration/023_update_apdex_ttl.go b/pkg/sqlmigration/023_update_apdex_ttl.go index 02294d5c3e..7f78e4d569 100644 --- a/pkg/sqlmigration/023_update_apdex_ttl.go +++ b/pkg/sqlmigration/023_update_apdex_ttl.go @@ -94,7 +94,7 @@ func (migration *updateApdexTtl) Up(ctx context.Context, db *bun.DB) error { err = migration. store. Dialect(). - RenameTableAndModifyModel(ctx, tx, new(existingApdexSettings), new(newApdexSettings), func(ctx context.Context) error { + RenameTableAndModifyModel(ctx, tx, new(existingApdexSettings), new(newApdexSettings), []string{OrgReference}, func(ctx context.Context) error { existingApdexSettings := make([]*existingApdexSettings, 0) err = tx. NewSelect(). @@ -133,7 +133,7 @@ func (migration *updateApdexTtl) Up(ctx context.Context, db *bun.DB) error { err = migration. store. Dialect(). - RenameTableAndModifyModel(ctx, tx, new(existingTTLStatus), new(newTTLStatus), func(ctx context.Context) error { + RenameTableAndModifyModel(ctx, tx, new(existingTTLStatus), new(newTTLStatus), []string{OrgReference}, func(ctx context.Context) error { existingTTLStatus := make([]*existingTTLStatus, 0) err = tx. NewSelect(). diff --git a/pkg/sqlmigration/024_update_reset_password.go b/pkg/sqlmigration/024_update_reset_password.go new file mode 100644 index 0000000000..613c787986 --- /dev/null +++ b/pkg/sqlmigration/024_update_reset_password.go @@ -0,0 +1,200 @@ +package sqlmigration + +import ( + "context" + "database/sql" + + "github.com/SigNoz/signoz/pkg/factory" + "github.com/SigNoz/signoz/pkg/sqlstore" + "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/valuer" + "github.com/uptrace/bun" + "github.com/uptrace/bun/migrate" +) + +type updateResetPassword struct { + store sqlstore.SQLStore +} + +type existingResetPasswordRequest struct { + bun.BaseModel `bun:"table:reset_password_request"` + ID int `bun:"id,pk,autoincrement" json:"id"` + Token string `bun:"token,type:text,notnull" json:"token"` + UserID string `bun:"user_id,type:text,notnull" json:"userId"` +} + +type newResetPasswordRequest struct { + bun.BaseModel `bun:"table:reset_password_request_new"` + types.Identifiable + Token string `bun:"token,type:text,notnull" json:"token"` + UserID string `bun:"user_id,type:text,notnull" json:"userId"` +} + +type existingPersonalAccessToken struct { + bun.BaseModel `bun:"table:personal_access_tokens"` + types.TimeAuditable + OrgID string `json:"orgId" bun:"org_id,type:text,notnull"` + ID int `json:"id" bun:"id,pk,autoincrement"` + Role string `json:"role" bun:"role,type:text,notnull,default:'ADMIN'"` + UserID string `json:"userId" bun:"user_id,type:text,notnull"` + Token string `json:"token" bun:"token,type:text,notnull,unique"` + Name string `json:"name" bun:"name,type:text,notnull"` + ExpiresAt int64 `json:"expiresAt" bun:"expires_at,notnull,default:0"` + LastUsed int64 `json:"lastUsed" bun:"last_used,notnull,default:0"` + Revoked bool `json:"revoked" bun:"revoked,notnull,default:false"` + UpdatedByUserID string `json:"updatedByUserId" bun:"updated_by_user_id,type:text,notnull,default:''"` +} + +type newPersonalAccessToken struct { + bun.BaseModel `bun:"table:personal_access_token"` + types.Identifiable + types.TimeAuditable + OrgID string `json:"orgId" bun:"org_id,type:text,notnull"` + Role string `json:"role" bun:"role,type:text,notnull,default:'ADMIN'"` + UserID string `json:"userId" bun:"user_id,type:text,notnull"` + Token string `json:"token" bun:"token,type:text,notnull,unique"` + Name string `json:"name" bun:"name,type:text,notnull"` + ExpiresAt int64 `json:"expiresAt" bun:"expires_at,notnull,default:0"` + LastUsed int64 `json:"lastUsed" bun:"last_used,notnull,default:0"` + Revoked bool `json:"revoked" bun:"revoked,notnull,default:false"` + UpdatedByUserID string `json:"updatedByUserId" bun:"updated_by_user_id,type:text,notnull,default:''"` +} + +func NewUpdateResetPasswordFactory(sqlstore sqlstore.SQLStore) factory.ProviderFactory[SQLMigration, Config] { + return factory. + NewProviderFactory( + factory.MustNewName("update_reset_password"), + func(ctx context.Context, ps factory.ProviderSettings, c Config) (SQLMigration, error) { + return newUpdateResetPassword(ctx, ps, c, sqlstore) + }) +} + +func newUpdateResetPassword(_ context.Context, _ factory.ProviderSettings, _ Config, store sqlstore.SQLStore) (SQLMigration, error) { + return &updateResetPassword{store: store}, nil +} + +func (migration *updateResetPassword) Register(migrations *migrate.Migrations) error { + if err := migrations. + Register(migration.Up, migration.Down); err != nil { + return err + } + + return nil +} + +func (migration *updateResetPassword) Up(ctx context.Context, db *bun.DB) error { + tx, err := db. + BeginTx(ctx, nil) + if err != nil { + return err + } + + defer tx.Rollback() + + err = migration.store.Dialect().UpdatePrimaryKey(ctx, tx, new(existingResetPasswordRequest), new(newResetPasswordRequest), UserReference, func(ctx context.Context) error { + existingResetPasswordRequests := make([]*existingResetPasswordRequest, 0) + err = tx. + NewSelect(). + Model(&existingResetPasswordRequests). + Scan(ctx) + if err != nil { + if err != sql.ErrNoRows { + return err + } + } + + if err == nil && len(existingResetPasswordRequests) > 0 { + newResetPasswordRequests := migration. + CopyExistingResetPasswordRequestsToNewResetPasswordRequests(existingResetPasswordRequests) + _, err = tx. + NewInsert(). + Model(&newResetPasswordRequests). + Exec(ctx) + if err != nil { + return err + } + } + return nil + }) + + if err != nil { + return err + } + + err = migration.store.Dialect().RenameTableAndModifyModel(ctx, tx, new(existingPersonalAccessToken), new(newPersonalAccessToken), []string{OrgReference, UserReference}, func(ctx context.Context) error { + existingPersonalAccessTokens := make([]*existingPersonalAccessToken, 0) + err = tx. + NewSelect(). + Model(&existingPersonalAccessTokens). + Scan(ctx) + if err != nil { + if err != sql.ErrNoRows { + return err + } + } + + if err == nil && len(existingPersonalAccessTokens) > 0 { + newPersonalAccessTokens := migration. + CopyExistingPATsToNewPATs(existingPersonalAccessTokens) + _, err = tx.NewInsert().Model(&newPersonalAccessTokens).Exec(ctx) + if err != nil { + return err + } + } + return nil + }) + + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func (migration *updateResetPassword) Down(context.Context, *bun.DB) error { + return nil +} + +func (migration *updateResetPassword) CopyExistingResetPasswordRequestsToNewResetPasswordRequests(existingPasswordRequests []*existingResetPasswordRequest) []*newResetPasswordRequest { + newResetPasswordRequests := make([]*newResetPasswordRequest, 0) + for _, request := range existingPasswordRequests { + newResetPasswordRequests = append(newResetPasswordRequests, &newResetPasswordRequest{ + Identifiable: types.Identifiable{ + ID: valuer.GenerateUUID(), + }, + Token: request.Token, + UserID: request.UserID, + }) + } + return newResetPasswordRequests +} + +func (migration *updateResetPassword) CopyExistingPATsToNewPATs(existingPATs []*existingPersonalAccessToken) []*newPersonalAccessToken { + newPATs := make([]*newPersonalAccessToken, 0) + for _, pat := range existingPATs { + newPATs = append(newPATs, &newPersonalAccessToken{ + Identifiable: types.Identifiable{ + ID: valuer.GenerateUUID(), + }, + TimeAuditable: types.TimeAuditable{ + CreatedAt: pat.CreatedAt, + UpdatedAt: pat.UpdatedAt, + }, + Role: pat.Role, + Name: pat.Name, + ExpiresAt: pat.ExpiresAt, + LastUsed: pat.LastUsed, + UserID: pat.UserID, + Token: pat.Token, + Revoked: pat.Revoked, + UpdatedByUserID: pat.UpdatedByUserID, + OrgID: pat.OrgID, + }) + } + return newPATs +} diff --git a/pkg/sqlstore/sqlitesqlstore/dialect.go b/pkg/sqlstore/sqlitesqlstore/dialect.go index 688a0195d2..e15976f914 100644 --- a/pkg/sqlstore/sqlitesqlstore/dialect.go +++ b/pkg/sqlstore/sqlitesqlstore/dialect.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "slices" "github.com/SigNoz/signoz/pkg/errors" "github.com/uptrace/bun" @@ -183,7 +184,10 @@ func (dialect *dialect) TableExists(ctx context.Context, bun bun.IDB, table inte return true, nil } -func (dialect *dialect) RenameTableAndModifyModel(ctx context.Context, bun bun.IDB, oldModel interface{}, newModel interface{}, cb func(context.Context) error) error { +func (dialect *dialect) RenameTableAndModifyModel(ctx context.Context, bun bun.IDB, oldModel interface{}, newModel interface{}, references []string, cb func(context.Context) error) error { + if len(references) == 0 { + return errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "cannot run migration without reference") + } exists, err := dialect.TableExists(ctx, bun, newModel) if err != nil { return err @@ -192,13 +196,25 @@ func (dialect *dialect) RenameTableAndModifyModel(ctx context.Context, bun bun.I return nil } - _, err = bun. + var fkReferences []string + for _, reference := range references { + if reference == Org && !slices.Contains(fkReferences, OrgReference) { + fkReferences = append(fkReferences, OrgReference) + } else if reference == User && !slices.Contains(fkReferences, UserReference) { + fkReferences = append(fkReferences, UserReference) + } + } + + createTable := bun. NewCreateTable(). IfNotExists(). - Model(newModel). - ForeignKey(`("org_id") REFERENCES "organizations" ("id")`). - Exec(ctx) + Model(newModel) + for _, fk := range fkReferences { + createTable = createTable.ForeignKey(fk) + } + + _, err = createTable.Exec(ctx) if err != nil { return err } diff --git a/pkg/sqlstore/sqlstore.go b/pkg/sqlstore/sqlstore.go index 254388ce06..6a9513e80a 100644 --- a/pkg/sqlstore/sqlstore.go +++ b/pkg/sqlstore/sqlstore.go @@ -43,7 +43,7 @@ type SQLDialect interface { GetColumnType(context.Context, bun.IDB, string, string) (string, error) ColumnExists(context.Context, bun.IDB, string, string) (bool, error) RenameColumn(context.Context, bun.IDB, string, string, string) (bool, error) - RenameTableAndModifyModel(context.Context, bun.IDB, interface{}, interface{}, func(context.Context) error) error + RenameTableAndModifyModel(context.Context, bun.IDB, interface{}, interface{}, []string, func(context.Context) error) error UpdatePrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error AddPrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error } diff --git a/pkg/sqlstore/sqlstoretest/dialect.go b/pkg/sqlstore/sqlstoretest/dialect.go index b3d09183fb..f30a3c7915 100644 --- a/pkg/sqlstore/sqlstoretest/dialect.go +++ b/pkg/sqlstore/sqlstoretest/dialect.go @@ -29,7 +29,7 @@ func (dialect *dialect) RenameColumn(ctx context.Context, bun bun.IDB, table str return true, nil } -func (dialect *dialect) RenameTableAndModifyModel(ctx context.Context, bun bun.IDB, oldModel interface{}, newModel interface{}, cb func(context.Context) error) error { +func (dialect *dialect) RenameTableAndModifyModel(ctx context.Context, bun bun.IDB, oldModel interface{}, newModel interface{}, references []string, cb func(context.Context) error) error { return nil } diff --git a/pkg/types/user.go b/pkg/types/user.go index 33afceccb6..ba9e1eb965 100644 --- a/pkg/types/user.go +++ b/pkg/types/user.go @@ -46,7 +46,7 @@ type User struct { type ResetPasswordRequest struct { bun.BaseModel `bun:"table:reset_password_request"` - ID int `bun:"id,pk,autoincrement" json:"id"` - Token string `bun:"token,type:text,notnull" json:"token"` - UserID string `bun:"user_id,type:text,notnull" json:"userId"` + Identifiable + Token string `bun:"token,type:text,notnull" json:"token"` + UserID string `bun:"user_id,type:text,notnull" json:"userId"` }