feat(sqlstore): add transaction support for sqlstore (#7224)

### Summary

- add transaction support for sqlstore
- use transactions in alertmanager
This commit is contained in:
Vibhu Pandey 2025-03-05 18:50:48 +05:30 committed by GitHub
parent c2d038c025
commit 02865cf49e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 296 additions and 145 deletions

View File

@ -314,7 +314,7 @@ func (server *Server) SetConfig(ctx context.Context, alertmanagerConfig *alertma
}
func (server *Server) TestReceiver(ctx context.Context, receiver alertmanagertypes.Receiver) error {
return alertmanagertypes.TestReceiver(ctx, receiver, server.tmpl, server.logger, alertmanagertypes.NewTestAlert(receiver, time.Now(), time.Now()))
return alertmanagertypes.TestReceiver(ctx, receiver, server.alertmanagerConfig, server.tmpl, server.logger, alertmanagertypes.NewTestAlert(receiver, time.Now(), time.Now()))
}
func (server *Server) TestAlert(ctx context.Context, postableAlert *alertmanagertypes.PostableAlert, receivers []string) error {
@ -335,7 +335,7 @@ func (server *Server) TestAlert(ctx context.Context, postableAlert *alertmanager
ch <- err
return
}
ch <- alertmanagertypes.TestReceiver(ctx, receiver, server.tmpl, server.logger, alerts[0])
ch <- alertmanagertypes.TestReceiver(ctx, receiver, server.alertmanagerConfig, server.tmpl, server.logger, alerts[0])
}(receiverName)
}

View File

@ -48,21 +48,23 @@ func (store *config) Get(ctx context.Context, orgID string) (*alertmanagertypes.
}
// Set implements alertmanagertypes.ConfigStore.
func (store *config) Set(ctx context.Context, config *alertmanagertypes.Config) error {
if _, err := store.
sqlstore.
BunDB().
NewInsert().
Model(config.StoreableConfig()).
On("CONFLICT (org_id) DO UPDATE").
Set("config = ?", config.StoreableConfig().Config).
Set("hash = ?", config.StoreableConfig().Hash).
Set("updated_at = ?", config.StoreableConfig().UpdatedAt).
Exec(ctx); err != nil {
return err
}
func (store *config) Set(ctx context.Context, config *alertmanagertypes.Config, opts ...alertmanagertypes.StoreOption) error {
return store.wrap(ctx, func(ctx context.Context) error {
if _, err := store.
sqlstore.
BunDBCtx(ctx).
NewInsert().
Model(config.StoreableConfig()).
On("CONFLICT (org_id) DO UPDATE").
Set("config = ?", config.StoreableConfig().Config).
Set("hash = ?", config.StoreableConfig().Hash).
Set("updated_at = ?", config.StoreableConfig().UpdatedAt).
Exec(ctx); err != nil {
return err
}
return nil
return nil
}, opts...)
}
func (store *config) ListOrgs(ctx context.Context) ([]string, error) {
@ -82,31 +84,19 @@ func (store *config) ListOrgs(ctx context.Context) ([]string, error) {
return orgIDs, nil
}
func (store *config) CreateChannel(ctx context.Context, channel *alertmanagertypes.Channel, cb func(context.Context) error) error {
tx, err := store.sqlstore.BunDB().BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback() //nolint:errcheck
if _, err = tx.NewInsert().
Model(channel).
Exec(ctx); err != nil {
return err
}
if cb != nil {
if err = cb(ctx); err != nil {
func (store *config) CreateChannel(ctx context.Context, channel *alertmanagertypes.Channel, opts ...alertmanagertypes.StoreOption) error {
return store.wrap(ctx, func(ctx context.Context) error {
if _, err := store.
sqlstore.
BunDBCtx(ctx).
NewInsert().
Model(channel).
Exec(ctx); err != nil {
return err
}
}
if err = tx.Commit(); err != nil {
return err
}
return nil
return nil
}, opts...)
}
func (store *config) GetChannelByID(ctx context.Context, orgID string, id int) (*alertmanagertypes.Channel, error) {
@ -130,65 +120,39 @@ func (store *config) GetChannelByID(ctx context.Context, orgID string, id int) (
return channel, nil
}
func (store *config) UpdateChannel(ctx context.Context, orgID string, channel *alertmanagertypes.Channel, cb func(context.Context) error) error {
tx, err := store.sqlstore.BunDB().BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback() //nolint:errcheck
_, err = tx.NewUpdate().
Model(channel).
WherePK().
Exec(ctx)
if err != nil {
return err
}
if cb != nil {
if err = cb(ctx); err != nil {
func (store *config) UpdateChannel(ctx context.Context, orgID string, channel *alertmanagertypes.Channel, opts ...alertmanagertypes.StoreOption) error {
return store.wrap(ctx, func(ctx context.Context) error {
if _, err := store.
sqlstore.
BunDBCtx(ctx).
NewUpdate().
Model(channel).
WherePK().
Exec(ctx); err != nil {
return err
}
}
if err = tx.Commit(); err != nil {
return err
}
return nil
return nil
}, opts...)
}
func (store *config) DeleteChannelByID(ctx context.Context, orgID string, id int, cb func(context.Context) error) error {
channel := new(alertmanagertypes.Channel)
func (store *config) DeleteChannelByID(ctx context.Context, orgID string, id int, opts ...alertmanagertypes.StoreOption) error {
return store.wrap(ctx, func(ctx context.Context) error {
channel := new(alertmanagertypes.Channel)
tx, err := store.sqlstore.BunDB().BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback() //nolint:errcheck
_, err = tx.NewDelete().
Model(channel).
Where("org_id = ?", orgID).
Where("id = ?", id).
Exec(ctx)
if err != nil {
return err
}
if cb != nil {
if err = cb(ctx); err != nil {
if _, err := store.
sqlstore.
BunDBCtx(ctx).
NewDelete().
Model(channel).
Where("org_id = ?", orgID).
Where("id = ?", id).
Exec(ctx); err != nil {
return err
}
}
if err = tx.Commit(); err != nil {
return err
}
return nil
return nil
}, opts...)
}
func (store *config) ListChannels(ctx context.Context, orgID string) ([]*alertmanagertypes.Channel, error) {
@ -254,3 +218,19 @@ func (store *config) GetMatchers(ctx context.Context, orgID string) (map[string]
return matchersMap, nil
}
func (store *config) wrap(ctx context.Context, fn func(ctx context.Context) error, opts ...alertmanagertypes.StoreOption) error {
storeOpts := alertmanagertypes.NewStoreOptions(opts...)
if storeOpts.Cb == nil {
return fn(ctx)
}
return store.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error {
if err := fn(ctx); err != nil {
return err
}
return storeOpts.Cb(ctx)
})
}

View File

@ -252,7 +252,16 @@ func (provider *provider) UpdateChannelByReceiverAndID(ctx context.Context, orgI
return err
}
err = provider.configStore.UpdateChannel(ctx, orgID, channel, func(ctx context.Context) error {
config, err := provider.configStore.Get(ctx, orgID)
if err != nil {
return err
}
if err := config.UpdateReceiver(receiver); err != nil {
return err
}
err = provider.configStore.UpdateChannel(ctx, orgID, channel, alertmanagertypes.WithCb(func(ctx context.Context) error {
url := provider.url.JoinPath(routesPath)
body, err := json.Marshal(receiver)
@ -278,8 +287,12 @@ func (provider *provider) UpdateChannelByReceiverAndID(ctx context.Context, orgI
return fmt.Errorf("bad response status %v", resp.Status)
}
if err := provider.configStore.Set(ctx, config); err != nil {
return err
}
return nil
})
}))
if err != nil {
return err
}
@ -290,7 +303,16 @@ func (provider *provider) UpdateChannelByReceiverAndID(ctx context.Context, orgI
func (provider *provider) CreateChannel(ctx context.Context, orgID string, receiver alertmanagertypes.Receiver) error {
channel := alertmanagertypes.NewChannelFromReceiver(receiver, orgID)
err := provider.configStore.CreateChannel(ctx, channel, func(ctx context.Context) error {
config, err := provider.configStore.Get(ctx, orgID)
if err != nil {
return err
}
if err := config.CreateReceiver(receiver); err != nil {
return err
}
return provider.configStore.CreateChannel(ctx, channel, alertmanagertypes.WithCb(func(ctx context.Context) error {
url := provider.url.JoinPath(routesPath)
body, err := json.Marshal(receiver)
@ -316,22 +338,30 @@ func (provider *provider) CreateChannel(ctx context.Context, orgID string, recei
return fmt.Errorf("bad response status %v", resp.Status)
}
if err := provider.configStore.Set(ctx, config); err != nil {
return err
}
return nil
})
}))
}
func (provider *provider) DeleteChannelByID(ctx context.Context, orgID string, channelID int) error {
channel, err := provider.configStore.GetChannelByID(ctx, orgID, channelID)
if err != nil {
return err
}
return nil
}
config, err := provider.configStore.Get(ctx, orgID)
if err != nil {
return err
}
func (provider *provider) DeleteChannelByID(ctx context.Context, orgID string, channelID int) error {
err := provider.configStore.DeleteChannelByID(ctx, orgID, channelID, func(ctx context.Context) error {
channel, err := provider.configStore.GetChannelByID(ctx, orgID, channelID)
if err != nil {
return err
}
if err := config.DeleteReceiver(channel.Name); err != nil {
return err
}
return provider.configStore.DeleteChannelByID(ctx, orgID, channelID, alertmanagertypes.WithCb(func(ctx context.Context) error {
url := provider.url.JoinPath(routesPath)
body, err := json.Marshal(map[string]string{"name": channel.Name})
@ -357,13 +387,12 @@ func (provider *provider) DeleteChannelByID(ctx context.Context, orgID string, c
return fmt.Errorf("bad response status %v", resp.Status)
}
return nil
})
if err != nil {
return err
}
if err := provider.configStore.Set(ctx, config); err != nil {
return err
}
return nil
return nil
}))
}
func (provider *provider) SetConfig(ctx context.Context, config *alertmanagertypes.Config) error {

View File

@ -109,6 +109,10 @@ func (provider *provider) UpdateChannelByReceiverAndID(ctx context.Context, orgI
return err
}
if err := channel.Update(receiver); err != nil {
return err
}
config, err := provider.configStore.Get(ctx, orgID)
if err != nil {
return err
@ -118,15 +122,9 @@ func (provider *provider) UpdateChannelByReceiverAndID(ctx context.Context, orgI
return err
}
if err := provider.configStore.Set(ctx, config); err != nil {
return err
}
if err := channel.Update(receiver); err != nil {
return err
}
return provider.configStore.UpdateChannel(ctx, orgID, channel, alertmanagertypes.ConfigStoreNoopCallback)
return provider.configStore.UpdateChannel(ctx, orgID, channel, alertmanagertypes.WithCb(func(ctx context.Context) error {
return provider.configStore.Set(ctx, config)
}))
}
func (provider *provider) DeleteChannelByID(ctx context.Context, orgID string, channelID int) error {
@ -144,11 +142,9 @@ func (provider *provider) DeleteChannelByID(ctx context.Context, orgID string, c
return err
}
if err := provider.configStore.Set(ctx, config); err != nil {
return err
}
return provider.configStore.DeleteChannelByID(ctx, orgID, channelID, alertmanagertypes.ConfigStoreNoopCallback)
return provider.configStore.DeleteChannelByID(ctx, orgID, channelID, alertmanagertypes.WithCb(func(ctx context.Context) error {
return provider.configStore.Set(ctx, config)
}))
}
func (provider *provider) CreateChannel(ctx context.Context, orgID string, receiver alertmanagertypes.Receiver) error {
@ -161,12 +157,10 @@ func (provider *provider) CreateChannel(ctx context.Context, orgID string, recei
return err
}
if err := provider.configStore.Set(ctx, config); err != nil {
return err
}
channel := alertmanagertypes.NewChannelFromReceiver(receiver, orgID)
return provider.configStore.CreateChannel(ctx, channel, alertmanagertypes.ConfigStoreNoopCallback)
return provider.configStore.CreateChannel(ctx, channel, alertmanagertypes.WithCb(func(ctx context.Context) error {
return provider.configStore.Set(ctx, config)
}))
}
func (provider *provider) SetConfig(ctx context.Context, config *alertmanagertypes.Config) error {

View File

@ -1,18 +1,72 @@
package sqlstore
import (
"context"
"database/sql"
"github.com/uptrace/bun"
"github.com/uptrace/bun/schema"
"go.signoz.io/signoz/pkg/errors"
"go.signoz.io/signoz/pkg/factory"
)
func NewBunDB(sqldb *sql.DB, dialect schema.Dialect, hooks []SQLStoreHook, opts ...bun.DBOption) *bun.DB {
bunDB := bun.NewDB(sqldb, dialect, opts...)
type transactorKey struct{}
type BunDB struct {
*bun.DB
settings factory.ScopedProviderSettings
}
func NewBunDB(settings factory.ScopedProviderSettings, sqldb *sql.DB, dialect schema.Dialect, hooks []SQLStoreHook, opts ...bun.DBOption) *BunDB {
db := bun.NewDB(sqldb, dialect, opts...)
for _, hook := range hooks {
bunDB.AddQueryHook(hook)
db.AddQueryHook(hook)
}
return bunDB
return &BunDB{db, settings}
}
func (db *BunDB) RunInTxCtx(ctx context.Context, opts *sql.TxOptions, cb func(ctx context.Context) error) error {
tx, ok := txFromContext(ctx)
if ok {
return cb(ctx)
}
// begin transaction
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return errors.Wrapf(err, errors.TypeInternal, errors.CodeInternal, "cannot begin transaction")
}
defer func() {
if err := tx.Rollback(); err != nil {
if err != sql.ErrTxDone {
db.settings.Logger().ErrorContext(ctx, "cannot rollback transaction", "error", err)
}
}
}()
if err := cb(newContextWithTx(ctx, tx)); err != nil {
return err
}
return tx.Commit()
}
func (db *BunDB) BunDBCtx(ctx context.Context) bun.IDB {
tx, ok := txFromContext(ctx)
if !ok {
return db.DB
}
return tx
}
func newContextWithTx(ctx context.Context, tx bun.Tx) context.Context {
return context.WithValue(ctx, transactorKey{}, tx)
}
func txFromContext(ctx context.Context) (bun.Tx, bool) {
tx, ok := ctx.Value(transactorKey{}).(bun.Tx)
return tx, ok
}

View File

@ -16,7 +16,7 @@ import (
type provider struct {
settings factory.ScopedProviderSettings
sqldb *sql.DB
bundb *bun.DB
bundb *sqlstore.BunDB
sqlxdb *sqlx.DB
}
@ -57,13 +57,13 @@ func New(ctx context.Context, providerSettings factory.ProviderSettings, config
return &provider{
settings: settings,
sqldb: sqldb,
bundb: sqlstore.NewBunDB(sqldb, pgdialect.New(), hooks),
bundb: sqlstore.NewBunDB(settings, sqldb, pgdialect.New(), hooks),
sqlxdb: sqlx.NewDb(sqldb, "postgres"),
}, nil
}
func (provider *provider) BunDB() *bun.DB {
return provider.bundb
return provider.bundb.DB
}
func (provider *provider) SQLDB() *sql.DB {
@ -73,3 +73,11 @@ func (provider *provider) SQLDB() *sql.DB {
func (provider *provider) SQLxDB() *sqlx.DB {
return provider.sqlxdb
}
func (provider *provider) BunDBCtx(ctx context.Context) bun.IDB {
return provider.bundb.BunDBCtx(ctx)
}
func (provider *provider) RunInTxCtx(ctx context.Context, opts *sql.TxOptions, cb func(ctx context.Context) error) error {
return provider.bundb.RunInTxCtx(ctx, opts, cb)
}

View File

@ -15,7 +15,7 @@ import (
type provider struct {
settings factory.ScopedProviderSettings
sqldb *sql.DB
bundb *bun.DB
bundb *sqlstore.BunDB
sqlxdb *sqlx.DB
}
@ -47,13 +47,13 @@ func New(ctx context.Context, providerSettings factory.ProviderSettings, config
return &provider{
settings: settings,
sqldb: sqldb,
bundb: sqlstore.NewBunDB(sqldb, sqlitedialect.New(), hooks),
bundb: sqlstore.NewBunDB(settings, sqldb, sqlitedialect.New(), hooks),
sqlxdb: sqlx.NewDb(sqldb, "sqlite3"),
}, nil
}
func (provider *provider) BunDB() *bun.DB {
return provider.bundb
return provider.bundb.DB
}
func (provider *provider) SQLDB() *sql.DB {
@ -63,3 +63,11 @@ func (provider *provider) SQLDB() *sql.DB {
func (provider *provider) SQLxDB() *sqlx.DB {
return provider.sqlxdb
}
func (provider *provider) BunDBCtx(ctx context.Context) bun.IDB {
return provider.bundb.BunDBCtx(ctx)
}
func (provider *provider) RunInTxCtx(ctx context.Context, opts *sql.TxOptions, cb func(ctx context.Context) error) error {
return provider.bundb.RunInTxCtx(ctx, opts, cb)
}

View File

@ -1,20 +1,32 @@
package sqlstore
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
"github.com/uptrace/bun"
)
// SQLStore is the interface for the SQLStore.
type SQLStoreTxOptions = sql.TxOptions
type SQLStore interface {
// SQLDB returns the underlying sql.DB.
SQLDB() *sql.DB
// BunDB returns an instance of bun.DB. This is the recommended way to interact with the database.
BunDB() *bun.DB
// SQLxDB returns an instance of sqlx.DB.
// SQLxDB returns an instance of sqlx.DB. This is the legacy ORM used.
SQLxDB() *sqlx.DB
// RunInTxCtx runs the given callback in a transaction. It creates and injects a new context with the transaction.
// If a transaction is present in the context, it will be used.
RunInTxCtx(ctx context.Context, opts *SQLStoreTxOptions, cb func(ctx context.Context) error) error
// BunDBCtx returns an instance of bun.IDB for the given context.
// If a transaction is present in the context, it will be used. Otherwise, the default will be used.
BunDBCtx(ctx context.Context) bun.IDB
}
type SQLStoreHook interface {

View File

@ -1,6 +1,7 @@
package sqlstoretest
import (
"context"
"database/sql"
"fmt"
@ -59,3 +60,11 @@ func (provider *Provider) SQLxDB() *sqlx.DB {
func (provider *Provider) Mock() sqlmock.Sqlmock {
return provider.mock
}
func (provider *Provider) BunDBCtx(ctx context.Context) bun.IDB {
return provider.bunDB
}
func (provider *Provider) RunInTxCtx(ctx context.Context, opts *sql.TxOptions, cb func(ctx context.Context) error) error {
return cb(ctx)
}

View File

@ -128,6 +128,24 @@ func newConfigHash(s string) [16]byte {
return md5.Sum([]byte(s))
}
func (c *Config) CopyWithReset() (*Config, error) {
newConfig, err := NewDefaultConfig(
*c.alertmanagerConfig.Global,
RouteConfig{
GroupByStr: c.alertmanagerConfig.Route.GroupByStr,
GroupInterval: time.Duration(*c.alertmanagerConfig.Route.GroupInterval),
GroupWait: time.Duration(*c.alertmanagerConfig.Route.GroupWait),
RepeatInterval: time.Duration(*c.alertmanagerConfig.Route.RepeatInterval),
},
c.storeableConfig.OrgID,
)
if err != nil {
return nil, err
}
return newConfig, nil
}
func (c *Config) SetGlobalConfig(globalConfig GlobalConfig) {
c.alertmanagerConfig.Global = &globalConfig
c.storeableConfig.Config = string(newRawFromConfig(c.alertmanagerConfig))
@ -201,7 +219,7 @@ func (c *Config) GetReceiver(name string) (Receiver, error) {
}
}
return Receiver{}, errors.Newf(errors.TypeInvalidInput, ErrCodeAlertmanagerChannelNotFound, "channel with name %q not found", name)
return Receiver{}, errors.Newf(errors.TypeNotFound, ErrCodeAlertmanagerChannelNotFound, "channel with name %q not found", name)
}
func (c *Config) UpdateReceiver(receiver config.Receiver) error {
@ -316,9 +334,33 @@ func (c *Config) ReceiverNamesFromRuleID(ruleID string) ([]string, error) {
return receiverNames, nil
}
type storeOptions struct {
Cb func(context.Context) error
}
type StoreOption func(*storeOptions)
func WithCb(cb func(context.Context) error) StoreOption {
return func(o *storeOptions) {
o.Cb = cb
}
}
func NewStoreOptions(opts ...StoreOption) *storeOptions {
o := &storeOptions{
Cb: nil,
}
for _, opt := range opts {
opt(o)
}
return o
}
type ConfigStore interface {
// Set creates or updates a config.
Set(context.Context, *Config) error
Set(context.Context, *Config, ...StoreOption) error
// Get returns the config for the given orgID
Get(context.Context, string) (*Config, error)
@ -327,16 +369,16 @@ type ConfigStore interface {
ListOrgs(context.Context) ([]string, error)
// CreateChannel creates a new channel.
CreateChannel(context.Context, *Channel, func(context.Context) error) error
CreateChannel(context.Context, *Channel, ...StoreOption) error
// GetChannelByID returns the channel for the given id.
GetChannelByID(context.Context, string, int) (*Channel, error)
// UpdateChannel updates a channel.
UpdateChannel(context.Context, string, *Channel, func(context.Context) error) error
UpdateChannel(context.Context, string, *Channel, ...StoreOption) error
// DeleteChannelByID deletes a channel.
DeleteChannelByID(context.Context, string, int, func(context.Context) error) error
DeleteChannelByID(context.Context, string, int, ...StoreOption) error
// ListChannels returns the list of channels.
ListChannels(context.Context, string) ([]*Channel, error)
@ -349,8 +391,6 @@ type ConfigStore interface {
GetMatchers(context.Context, string) (map[string][]string, error)
}
var ConfigStoreNoopCallback = func(ctx context.Context) error { return nil }
// MarshalSecretValue if set to true will expose Secret type
// through the marshal interfaces. We need to store the actual value of the secret
// in the database, so we need to set this to true.

View File

@ -9,6 +9,7 @@ import (
"github.com/prometheus/alertmanager/notify"
"github.com/prometheus/alertmanager/template"
"go.signoz.io/signoz/pkg/errors"
"github.com/prometheus/alertmanager/config"
"github.com/prometheus/alertmanager/config/receiver"
@ -37,16 +38,32 @@ func NewReceiverIntegrations(nc Receiver, tmpl *template.Template, logger *slog.
return receiver.BuildReceiverIntegrations(nc, tmpl, logger)
}
func TestReceiver(ctx context.Context, receiver Receiver, tmpl *template.Template, logger *slog.Logger, alert *Alert) error {
func TestReceiver(ctx context.Context, receiver Receiver, config *Config, tmpl *template.Template, logger *slog.Logger, alert *Alert) error {
ctx = notify.WithGroupKey(ctx, fmt.Sprintf("%s-%s-%d", receiver.Name, alert.Labels.Fingerprint(), time.Now().Unix()))
ctx = notify.WithGroupLabels(ctx, alert.Labels)
ctx = notify.WithReceiverName(ctx, receiver.Name)
// We need to create a new config with the same global and route config but empty receivers and routes
// This is so that we can call CreateReceiver without worrying about the existing receivers and routes.
// CreateReceiver will ensure that any defaults (such as http config in the case of slack) are set. Otherwise the integration will panic.
testConfig, err := config.CopyWithReset()
if err != nil {
return err
}
if err := testConfig.CreateReceiver(receiver); err != nil {
return err
}
integrations, err := NewReceiverIntegrations(receiver, tmpl, logger)
if err != nil {
return err
}
if len(integrations) == 0 {
return errors.Newf(errors.TypeNotFound, errors.CodeNotFound, "no integrations found for receiver %s", receiver.Name)
}
if _, err = integrations[0].Notify(ctx, alert); err != nil {
return err
}