fix: move pat and org domains towards postgres multitenancy (#7337)

* fix: inital commit for pat

* fix: add migration file

* fix: add domain changes

* fix: minor fixes

* fix: update migration

* fix: update migration

* fix: update pat and old migration

* fix: move domain and sso type to ee
This commit is contained in:
Nityananda Gohain 2025-03-20 13:59:52 +05:30 committed by GitHub
parent 0320285a25
commit 9d8e46e5b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 489 additions and 430 deletions

View File

@ -2,26 +2,69 @@ package middleware
import ( import (
"net/http" "net/http"
"time"
"github.com/uptrace/bun"
"go.signoz.io/signoz/pkg/types"
"go.signoz.io/signoz/pkg/types/authtypes" "go.signoz.io/signoz/pkg/types/authtypes"
"go.uber.org/zap"
) )
type Pat struct { type Pat struct {
db *bun.DB
uuid *authtypes.UUID uuid *authtypes.UUID
headers []string headers []string
} }
func NewPat(headers []string) *Pat { func NewPat(db *bun.DB, headers []string) *Pat {
return &Pat{uuid: authtypes.NewUUID(), headers: headers} return &Pat{db: db, uuid: authtypes.NewUUID(), headers: headers}
} }
func (p *Pat) Wrap(next http.Handler) http.Handler { func (p *Pat) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var values []string var values []string
var patToken string
var pat types.StorablePersonalAccessToken
var updateLastUsed bool
for _, header := range p.headers { for _, header := range p.headers {
values = append(values, r.Header.Get(header)) values = append(values, r.Header.Get(header))
if header == "SIGNOZ-API-KEY" {
patToken = values[0]
err := p.db.NewSelect().Model(&pat).Where("token = ?", patToken).Scan(r.Context())
if err != nil {
next.ServeHTTP(w, r)
return
} }
if pat.ExpiresAt < time.Now().Unix() && pat.ExpiresAt != 0 {
next.ServeHTTP(w, r)
return
}
// get user from db
user := types.User{}
err = p.db.NewSelect().Model(&user).Where("id = ?", pat.UserID).Scan(r.Context())
if err != nil {
next.ServeHTTP(w, r)
return
}
jwt := authtypes.Claims{
UserID: user.ID,
GroupID: user.GroupID,
Email: user.Email,
OrgID: user.OrgID,
}
ctx := authtypes.NewContextWithClaims(r.Context(), jwt)
r = r.WithContext(ctx)
// Mark to update last used since SIGNOZ-API-KEY is present and successful
updateLastUsed = true
}
}
ctx, err := p.uuid.ContextFromRequest(r.Context(), values...) ctx, err := p.uuid.ContextFromRequest(r.Context(), values...)
if err != nil { if err != nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -31,6 +74,16 @@ func (p *Pat) Wrap(next http.Handler) http.Handler {
r = r.WithContext(ctx) r = r.WithContext(ctx)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
// update last used only if SIGNOZ-API-KEY was present and successful
if updateLastUsed {
pat.LastUsed = time.Now().Unix()
_, err = p.db.NewUpdate().Model(&pat).Column("last_used").Where("token = ?", patToken).Where("revoked = false").Exec(r.Context())
if err != nil {
zap.L().Error("Failed to update PAT last used in db, err: %v", zap.Error(err))
}
}
}) })
} }

View File

@ -118,7 +118,7 @@ func (ah *APIHandler) getOrCreateCloudIntegrationPAT(ctx context.Context, orgId
return "", apiErr return "", apiErr
} }
allPats, err := ah.AppDao().ListPATs(ctx) allPats, err := ah.AppDao().ListPATs(ctx, orgId)
if err != nil { if err != nil {
return "", basemodel.InternalError(fmt.Errorf( return "", basemodel.InternalError(fmt.Errorf(
"couldn't list PATs: %w", err, "couldn't list PATs: %w", err,
@ -136,15 +136,19 @@ func (ah *APIHandler) getOrCreateCloudIntegrationPAT(ctx context.Context, orgId
) )
newPAT := model.PAT{ newPAT := model.PAT{
StorablePersonalAccessToken: types.StorablePersonalAccessToken{
Token: generatePATToken(), Token: generatePATToken(),
UserID: integrationUser.ID, UserID: integrationUser.ID,
Name: integrationPATName, Name: integrationPATName,
Role: baseconstants.ViewerGroup, Role: baseconstants.ViewerGroup,
ExpiresAt: 0, ExpiresAt: 0,
CreatedAt: time.Now().Unix(), TimeAuditable: types.TimeAuditable{
UpdatedAt: time.Now().Unix(), CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
},
} }
integrationPAT, err := ah.AppDao().CreatePAT(ctx, newPAT) integrationPAT, err := ah.AppDao().CreatePAT(ctx, orgId, newPAT)
if err != nil { if err != nil {
return "", basemodel.InternalError(fmt.Errorf( return "", basemodel.InternalError(fmt.Errorf(
"couldn't create cloud integration PAT: %w", err, "couldn't create cloud integration PAT: %w", err,

View File

@ -9,6 +9,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"go.signoz.io/signoz/ee/query-service/model" "go.signoz.io/signoz/ee/query-service/model"
"go.signoz.io/signoz/ee/types"
) )
func (ah *APIHandler) listDomainsByOrg(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) listDomainsByOrg(w http.ResponseWriter, r *http.Request) {
@ -24,7 +25,7 @@ func (ah *APIHandler) listDomainsByOrg(w http.ResponseWriter, r *http.Request) {
func (ah *APIHandler) postDomain(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) postDomain(w http.ResponseWriter, r *http.Request) {
ctx := context.Background() ctx := context.Background()
req := model.OrgDomain{} req := types.GettableOrgDomain{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
RespondError(w, model.BadRequest(err), nil) RespondError(w, model.BadRequest(err), nil)
@ -54,12 +55,12 @@ func (ah *APIHandler) putDomain(w http.ResponseWriter, r *http.Request) {
return return
} }
req := model.OrgDomain{Id: domainId} req := types.GettableOrgDomain{StorableOrgDomain: types.StorableOrgDomain{ID: domainId}}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
RespondError(w, model.BadRequest(err), nil) RespondError(w, model.BadRequest(err), nil)
return return
} }
req.Id = domainId req.ID = domainId
if err := req.Valid(nil); err != nil { if err := req.Valid(nil); err != nil {
RespondError(w, model.BadRequest(err), nil) RespondError(w, model.BadRequest(err), nil)
} }

View File

@ -14,6 +14,7 @@ import (
"go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/auth"
baseconstants "go.signoz.io/signoz/pkg/query-service/constants" baseconstants "go.signoz.io/signoz/pkg/query-service/constants"
basemodel "go.signoz.io/signoz/pkg/query-service/model" basemodel "go.signoz.io/signoz/pkg/query-service/model"
"go.signoz.io/signoz/pkg/types"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -43,9 +44,11 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
return return
} }
pat := model.PAT{ pat := model.PAT{
StorablePersonalAccessToken: types.StorablePersonalAccessToken{
Name: req.Name, Name: req.Name,
Role: req.Role, Role: req.Role,
ExpiresAt: req.ExpiresInDays, ExpiresAt: req.ExpiresInDays,
},
} }
err = validatePATRequest(pat) err = validatePATRequest(pat)
if err != nil { if err != nil {
@ -55,8 +58,8 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
// All the PATs are associated with the user creating the PAT. // All the PATs are associated with the user creating the PAT.
pat.UserID = user.ID pat.UserID = user.ID
pat.CreatedAt = time.Now().Unix() pat.CreatedAt = time.Now()
pat.UpdatedAt = time.Now().Unix() pat.UpdatedAt = time.Now()
pat.LastUsed = 0 pat.LastUsed = 0
pat.Token = generatePATToken() pat.Token = generatePATToken()
@ -67,7 +70,7 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
zap.L().Info("Got Create PAT request", zap.Any("pat", pat)) zap.L().Info("Got Create PAT request", zap.Any("pat", pat))
var apierr basemodel.BaseApiError var apierr basemodel.BaseApiError
if pat, apierr = ah.AppDao().CreatePAT(ctx, pat); apierr != nil { if pat, apierr = ah.AppDao().CreatePAT(ctx, user.OrgID, pat); apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }
@ -114,10 +117,10 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
req.UpdatedByUserID = user.ID req.UpdatedByUserID = user.ID
id := mux.Vars(r)["id"] id := mux.Vars(r)["id"]
req.UpdatedAt = time.Now().Unix() req.UpdatedAt = time.Now()
zap.L().Info("Got Update PAT request", zap.Any("pat", req)) zap.L().Info("Got Update PAT request", zap.Any("pat", req))
var apierr basemodel.BaseApiError var apierr basemodel.BaseApiError
if apierr = ah.AppDao().UpdatePAT(ctx, req, id); apierr != nil { if apierr = ah.AppDao().UpdatePAT(ctx, user.OrgID, req, id); apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }
@ -136,7 +139,7 @@ func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) {
return return
} }
zap.L().Info("Get PATs for user", zap.String("user_id", user.ID)) zap.L().Info("Get PATs for user", zap.String("user_id", user.ID))
pats, apierr := ah.AppDao().ListPATs(ctx) pats, apierr := ah.AppDao().ListPATs(ctx, user.OrgID)
if apierr != nil { if apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
@ -157,7 +160,7 @@ func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) {
} }
zap.L().Info("Revoke PAT with id", zap.String("id", id)) zap.L().Info("Revoke PAT with id", zap.String("id", id))
if apierr := ah.AppDao().RevokePAT(ctx, id, user.ID); apierr != nil { if apierr := ah.AppDao().RevokePAT(ctx, user.OrgID, id, user.ID); apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }

View File

@ -17,7 +17,6 @@ import (
eemiddleware "go.signoz.io/signoz/ee/http/middleware" eemiddleware "go.signoz.io/signoz/ee/http/middleware"
"go.signoz.io/signoz/ee/query-service/app/api" "go.signoz.io/signoz/ee/query-service/app/api"
"go.signoz.io/signoz/ee/query-service/app/db" "go.signoz.io/signoz/ee/query-service/app/db"
"go.signoz.io/signoz/ee/query-service/auth"
"go.signoz.io/signoz/ee/query-service/constants" "go.signoz.io/signoz/ee/query-service/constants"
"go.signoz.io/signoz/ee/query-service/dao" "go.signoz.io/signoz/ee/query-service/dao"
"go.signoz.io/signoz/ee/query-service/integrations/gateway" "go.signoz.io/signoz/ee/query-service/integrations/gateway"
@ -25,6 +24,7 @@ import (
"go.signoz.io/signoz/ee/query-service/rules" "go.signoz.io/signoz/ee/query-service/rules"
"go.signoz.io/signoz/pkg/alertmanager" "go.signoz.io/signoz/pkg/alertmanager"
"go.signoz.io/signoz/pkg/http/middleware" "go.signoz.io/signoz/pkg/http/middleware"
"go.signoz.io/signoz/pkg/query-service/auth"
"go.signoz.io/signoz/pkg/signoz" "go.signoz.io/signoz/pkg/signoz"
"go.signoz.io/signoz/pkg/sqlstore" "go.signoz.io/signoz/pkg/sqlstore"
"go.signoz.io/signoz/pkg/types" "go.signoz.io/signoz/pkg/types"
@ -317,7 +317,7 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server,
r := baseapp.NewRouter() r := baseapp.NewRouter()
r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
r.Use(eemiddleware.NewPat([]string{"SIGNOZ-API-KEY"}).Wrap) r.Use(eemiddleware.NewPat(s.serverOptions.SigNoz.SQLStore.BunDB(), []string{"SIGNOZ-API-KEY"}).Wrap)
r.Use(middleware.NewTimeout(zap.L(), r.Use(middleware.NewTimeout(zap.L(),
s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes, s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes,
s.serverOptions.Config.APIServer.Timeout.Default, s.serverOptions.Config.APIServer.Timeout.Default,
@ -350,7 +350,7 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*h
// add auth middleware // add auth middleware
getUserFromRequest := func(ctx context.Context) (*types.GettableUser, error) { getUserFromRequest := func(ctx context.Context) (*types.GettableUser, error) {
user, err := auth.GetUserFromRequestContext(ctx, apiHandler) user, err := auth.GetUserFromReqContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -365,7 +365,7 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*h
am := baseapp.NewAuthMiddleware(getUserFromRequest) am := baseapp.NewAuthMiddleware(getUserFromRequest)
r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
r.Use(eemiddleware.NewPat([]string{"SIGNOZ-API-KEY"}).Wrap) r.Use(eemiddleware.NewPat(s.serverOptions.SigNoz.SQLStore.BunDB(), []string{"SIGNOZ-API-KEY"}).Wrap)
r.Use(middleware.NewTimeout(zap.L(), r.Use(middleware.NewTimeout(zap.L(),
s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes, s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes,
s.serverOptions.Config.APIServer.Timeout.Default, s.serverOptions.Config.APIServer.Timeout.Default,

View File

@ -1,56 +0,0 @@
package auth
import (
"context"
"fmt"
"time"
"go.signoz.io/signoz/ee/query-service/app/api"
baseauth "go.signoz.io/signoz/pkg/query-service/auth"
"go.signoz.io/signoz/pkg/query-service/telemetry"
"go.signoz.io/signoz/pkg/types"
"go.signoz.io/signoz/pkg/types/authtypes"
"go.uber.org/zap"
)
func GetUserFromRequestContext(ctx context.Context, apiHandler *api.APIHandler) (*types.GettableUser, error) {
patToken, ok := authtypes.UUIDFromContext(ctx)
if ok && patToken != "" {
zap.L().Debug("Received a non-zero length PAT token")
ctx := context.Background()
dao := apiHandler.AppDao()
pat, err := dao.GetPAT(ctx, patToken)
if err == nil && pat != nil {
zap.L().Debug("Found valid PAT: ", zap.Any("pat", pat))
if pat.ExpiresAt < time.Now().Unix() && pat.ExpiresAt != 0 {
zap.L().Info("PAT has expired: ", zap.Any("pat", pat))
return nil, fmt.Errorf("PAT has expired")
}
group, apiErr := dao.GetGroupByName(ctx, pat.Role)
if apiErr != nil {
zap.L().Error("Error while getting group for PAT: ", zap.Any("apiErr", apiErr))
return nil, apiErr
}
user, err := dao.GetUser(ctx, pat.UserID)
if err != nil {
zap.L().Error("Error while getting user for PAT: ", zap.Error(err))
return nil, err
}
telemetry.GetInstance().SetPatTokenUser()
dao.UpdatePATLastUsed(ctx, patToken, time.Now().Unix())
user.User.GroupID = group.ID
user.User.ID = pat.Id
return &types.GettableUser{
User: user.User,
Role: pat.Role,
}, nil
}
if err != nil {
zap.L().Error("Error while getting user for PAT: ", zap.Error(err))
return nil, err
}
}
return baseauth.GetUserFromReqContext(ctx)
}

View File

@ -5,12 +5,13 @@ import (
"net/url" "net/url"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jmoiron/sqlx" "github.com/uptrace/bun"
"go.signoz.io/signoz/ee/query-service/model" "go.signoz.io/signoz/ee/query-service/model"
"go.signoz.io/signoz/ee/types"
basedao "go.signoz.io/signoz/pkg/query-service/dao" basedao "go.signoz.io/signoz/pkg/query-service/dao"
baseint "go.signoz.io/signoz/pkg/query-service/interfaces" baseint "go.signoz.io/signoz/pkg/query-service/interfaces"
basemodel "go.signoz.io/signoz/pkg/query-service/model" basemodel "go.signoz.io/signoz/pkg/query-service/model"
"go.signoz.io/signoz/pkg/types" ossTypes "go.signoz.io/signoz/pkg/types"
"go.signoz.io/signoz/pkg/types/authtypes" "go.signoz.io/signoz/pkg/types/authtypes"
) )
@ -20,27 +21,26 @@ type ModelDao interface {
// SetFlagProvider sets the feature lookup provider // SetFlagProvider sets the feature lookup provider
SetFlagProvider(flags baseint.FeatureLookup) SetFlagProvider(flags baseint.FeatureLookup)
DB() *sqlx.DB DB() *bun.DB
// auth methods // auth methods
CanUsePassword(ctx context.Context, email string) (bool, basemodel.BaseApiError) CanUsePassword(ctx context.Context, email string) (bool, basemodel.BaseApiError)
PrepareSsoRedirect(ctx context.Context, redirectUri, email string, jwt *authtypes.JWT) (redirectURL string, apierr basemodel.BaseApiError) PrepareSsoRedirect(ctx context.Context, redirectUri, email string, jwt *authtypes.JWT) (redirectURL string, apierr basemodel.BaseApiError)
GetDomainFromSsoResponse(ctx context.Context, relayState *url.URL) (*model.OrgDomain, error) GetDomainFromSsoResponse(ctx context.Context, relayState *url.URL) (*types.GettableOrgDomain, error)
// org domain (auth domains) CRUD ops // org domain (auth domains) CRUD ops
ListDomains(ctx context.Context, orgId string) ([]model.OrgDomain, basemodel.BaseApiError) ListDomains(ctx context.Context, orgId string) ([]types.GettableOrgDomain, basemodel.BaseApiError)
GetDomain(ctx context.Context, id uuid.UUID) (*model.OrgDomain, basemodel.BaseApiError) GetDomain(ctx context.Context, id uuid.UUID) (*types.GettableOrgDomain, basemodel.BaseApiError)
CreateDomain(ctx context.Context, d *model.OrgDomain) basemodel.BaseApiError CreateDomain(ctx context.Context, d *types.GettableOrgDomain) basemodel.BaseApiError
UpdateDomain(ctx context.Context, domain *model.OrgDomain) basemodel.BaseApiError UpdateDomain(ctx context.Context, domain *types.GettableOrgDomain) basemodel.BaseApiError
DeleteDomain(ctx context.Context, id uuid.UUID) basemodel.BaseApiError DeleteDomain(ctx context.Context, id uuid.UUID) basemodel.BaseApiError
GetDomainByEmail(ctx context.Context, email string) (*model.OrgDomain, basemodel.BaseApiError) GetDomainByEmail(ctx context.Context, email string) (*types.GettableOrgDomain, basemodel.BaseApiError)
CreatePAT(ctx context.Context, p model.PAT) (model.PAT, basemodel.BaseApiError) CreatePAT(ctx context.Context, orgID string, p model.PAT) (model.PAT, basemodel.BaseApiError)
UpdatePAT(ctx context.Context, p model.PAT, id string) basemodel.BaseApiError UpdatePAT(ctx context.Context, orgID string, p model.PAT, id string) basemodel.BaseApiError
GetPAT(ctx context.Context, pat string) (*model.PAT, basemodel.BaseApiError) GetPAT(ctx context.Context, pat string) (*model.PAT, basemodel.BaseApiError)
UpdatePATLastUsed(ctx context.Context, pat string, lastUsed int64) basemodel.BaseApiError GetPATByID(ctx context.Context, orgID string, id string) (*model.PAT, basemodel.BaseApiError)
GetPATByID(ctx context.Context, id string) (*model.PAT, basemodel.BaseApiError) GetUserByPAT(ctx context.Context, orgID string, token string) (*ossTypes.GettableUser, basemodel.BaseApiError)
GetUserByPAT(ctx context.Context, token string) (*types.GettableUser, basemodel.BaseApiError) ListPATs(ctx context.Context, orgID string) ([]model.PAT, basemodel.BaseApiError)
ListPATs(ctx context.Context) ([]model.PAT, basemodel.BaseApiError) RevokePAT(ctx context.Context, orgID string, id string, userID string) basemodel.BaseApiError
RevokePAT(ctx context.Context, id string, userID string) basemodel.BaseApiError
} }

View File

@ -53,7 +53,7 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) (
}, },
ProfilePictureURL: "", // Currently unused ProfilePictureURL: "", // Currently unused
GroupID: group.ID, GroupID: group.ID,
OrgID: domain.OrgId, OrgID: domain.OrgID,
} }
user, apiErr = m.CreateUser(ctx, user, false) user, apiErr = m.CreateUser(ctx, user, false)

View File

@ -11,30 +11,21 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"go.signoz.io/signoz/ee/query-service/model" "go.signoz.io/signoz/ee/query-service/model"
"go.signoz.io/signoz/ee/types"
basemodel "go.signoz.io/signoz/pkg/query-service/model" basemodel "go.signoz.io/signoz/pkg/query-service/model"
ossTypes "go.signoz.io/signoz/pkg/types"
"go.uber.org/zap" "go.uber.org/zap"
) )
// StoredDomain represents stored database record for org domain
type StoredDomain struct {
Id uuid.UUID `db:"id"`
Name string `db:"name"`
OrgId string `db:"org_id"`
Data string `db:"data"`
CreatedAt int64 `db:"created_at"`
UpdatedAt int64 `db:"updated_at"`
}
// GetDomainFromSsoResponse uses relay state received from IdP to fetch // GetDomainFromSsoResponse uses relay state received from IdP to fetch
// user domain. The domain is further used to process validity of the response. // user domain. The domain is further used to process validity of the response.
// when sending login request to IdP we send relay state as URL (site url) // when sending login request to IdP we send relay state as URL (site url)
// with domainId or domainName as query parameter. // with domainId or domainName as query parameter.
func (m *modelDao) GetDomainFromSsoResponse(ctx context.Context, relayState *url.URL) (*model.OrgDomain, error) { func (m *modelDao) GetDomainFromSsoResponse(ctx context.Context, relayState *url.URL) (*types.GettableOrgDomain, error) {
// derive domain id from relay state now // derive domain id from relay state now
var domainIdStr string var domainIdStr string
var domainNameStr string var domainNameStr string
var domain *model.OrgDomain var domain *types.GettableOrgDomain
for k, v := range relayState.Query() { for k, v := range relayState.Query() {
if k == "domainId" && len(v) > 0 { if k == "domainId" && len(v) > 0 {
@ -76,10 +67,14 @@ func (m *modelDao) GetDomainFromSsoResponse(ctx context.Context, relayState *url
} }
// GetDomainByName returns org domain for a given domain name // GetDomainByName returns org domain for a given domain name
func (m *modelDao) GetDomainByName(ctx context.Context, name string) (*model.OrgDomain, basemodel.BaseApiError) { func (m *modelDao) GetDomainByName(ctx context.Context, name string) (*types.GettableOrgDomain, basemodel.BaseApiError) {
stored := StoredDomain{} stored := types.StorableOrgDomain{}
err := m.DB().Get(&stored, `SELECT * FROM org_domains WHERE name=$1 LIMIT 1`, name) err := m.DB().NewSelect().
Model(&stored).
Where("name = ?", name).
Limit(1).
Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -88,7 +83,7 @@ func (m *modelDao) GetDomainByName(ctx context.Context, name string) (*model.Org
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
domain := &model.OrgDomain{Id: stored.Id, Name: stored.Name, OrgId: stored.OrgId} domain := &types.GettableOrgDomain{StorableOrgDomain: stored}
if err := domain.LoadConfig(stored.Data); err != nil { if err := domain.LoadConfig(stored.Data); err != nil {
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
@ -96,10 +91,14 @@ func (m *modelDao) GetDomainByName(ctx context.Context, name string) (*model.Org
} }
// GetDomain returns org domain for a given domain id // GetDomain returns org domain for a given domain id
func (m *modelDao) GetDomain(ctx context.Context, id uuid.UUID) (*model.OrgDomain, basemodel.BaseApiError) { func (m *modelDao) GetDomain(ctx context.Context, id uuid.UUID) (*types.GettableOrgDomain, basemodel.BaseApiError) {
stored := StoredDomain{} stored := types.StorableOrgDomain{}
err := m.DB().Get(&stored, `SELECT * FROM org_domains WHERE id=$1 LIMIT 1`, id) err := m.DB().NewSelect().
Model(&stored).
Where("id = ?", id).
Limit(1).
Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -108,7 +107,7 @@ func (m *modelDao) GetDomain(ctx context.Context, id uuid.UUID) (*model.OrgDomai
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
domain := &model.OrgDomain{Id: stored.Id, Name: stored.Name, OrgId: stored.OrgId} domain := &types.GettableOrgDomain{StorableOrgDomain: stored}
if err := domain.LoadConfig(stored.Data); err != nil { if err := domain.LoadConfig(stored.Data); err != nil {
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
@ -116,21 +115,24 @@ func (m *modelDao) GetDomain(ctx context.Context, id uuid.UUID) (*model.OrgDomai
} }
// ListDomains gets the list of auth domains by org id // ListDomains gets the list of auth domains by org id
func (m *modelDao) ListDomains(ctx context.Context, orgId string) ([]model.OrgDomain, basemodel.BaseApiError) { func (m *modelDao) ListDomains(ctx context.Context, orgId string) ([]types.GettableOrgDomain, basemodel.BaseApiError) {
domains := []model.OrgDomain{} domains := []types.GettableOrgDomain{}
stored := []StoredDomain{} stored := []types.StorableOrgDomain{}
err := m.DB().SelectContext(ctx, &stored, `SELECT * FROM org_domains WHERE org_id=$1`, orgId) err := m.DB().NewSelect().
Model(&stored).
Where("org_id = ?", orgId).
Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return []model.OrgDomain{}, nil return domains, nil
} }
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
for _, s := range stored { for _, s := range stored {
domain := model.OrgDomain{Id: s.Id, Name: s.Name, OrgId: s.OrgId} domain := types.GettableOrgDomain{StorableOrgDomain: s}
if err := domain.LoadConfig(s.Data); err != nil { if err := domain.LoadConfig(s.Data); err != nil {
zap.L().Error("ListDomains() failed", zap.Error(err)) zap.L().Error("ListDomains() failed", zap.Error(err))
} }
@ -141,14 +143,14 @@ func (m *modelDao) ListDomains(ctx context.Context, orgId string) ([]model.OrgDo
} }
// CreateDomain creates a new auth domain // CreateDomain creates a new auth domain
func (m *modelDao) CreateDomain(ctx context.Context, domain *model.OrgDomain) basemodel.BaseApiError { func (m *modelDao) CreateDomain(ctx context.Context, domain *types.GettableOrgDomain) basemodel.BaseApiError {
if domain.Id == uuid.Nil { if domain.ID == uuid.Nil {
domain.Id = uuid.New() domain.ID = uuid.New()
} }
if domain.OrgId == "" || domain.Name == "" { if domain.OrgID == "" || domain.Name == "" {
return model.BadRequest(fmt.Errorf("domain creation failed, missing fields: OrgId, Name ")) return model.BadRequest(fmt.Errorf("domain creation failed, missing fields: OrgID, Name "))
} }
configJson, err := json.Marshal(domain) configJson, err := json.Marshal(domain)
@ -157,14 +159,17 @@ func (m *modelDao) CreateDomain(ctx context.Context, domain *model.OrgDomain) ba
return model.InternalError(fmt.Errorf("domain creation failed")) return model.InternalError(fmt.Errorf("domain creation failed"))
} }
_, err = m.DB().ExecContext(ctx, storableDomain := types.StorableOrgDomain{
"INSERT INTO org_domains (id, name, org_id, data, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6)", ID: domain.ID,
domain.Id, Name: domain.Name,
domain.Name, OrgID: domain.OrgID,
domain.OrgId, Data: string(configJson),
configJson, TimeAuditable: ossTypes.TimeAuditable{CreatedAt: time.Now(), UpdatedAt: time.Now()},
time.Now().Unix(), }
time.Now().Unix())
_, err = m.DB().NewInsert().
Model(&storableDomain).
Exec(ctx)
if err != nil { if err != nil {
zap.L().Error("failed to insert domain in db", zap.Error(err)) zap.L().Error("failed to insert domain in db", zap.Error(err))
@ -175,9 +180,9 @@ func (m *modelDao) CreateDomain(ctx context.Context, domain *model.OrgDomain) ba
} }
// UpdateDomain updates stored config params for a domain // UpdateDomain updates stored config params for a domain
func (m *modelDao) UpdateDomain(ctx context.Context, domain *model.OrgDomain) basemodel.BaseApiError { func (m *modelDao) UpdateDomain(ctx context.Context, domain *types.GettableOrgDomain) basemodel.BaseApiError {
if domain.Id == uuid.Nil { if domain.ID == uuid.Nil {
zap.L().Error("domain update failed", zap.Error(fmt.Errorf("OrgDomain.Id is null"))) zap.L().Error("domain update failed", zap.Error(fmt.Errorf("OrgDomain.Id is null")))
return model.InternalError(fmt.Errorf("domain update failed")) return model.InternalError(fmt.Errorf("domain update failed"))
} }
@ -188,11 +193,19 @@ func (m *modelDao) UpdateDomain(ctx context.Context, domain *model.OrgDomain) ba
return model.InternalError(fmt.Errorf("domain update failed")) return model.InternalError(fmt.Errorf("domain update failed"))
} }
_, err = m.DB().ExecContext(ctx, storableDomain := &types.StorableOrgDomain{
"UPDATE org_domains SET data = $1, updated_at = $2 WHERE id = $3", ID: domain.ID,
configJson, Name: domain.Name,
time.Now().Unix(), OrgID: domain.OrgID,
domain.Id) Data: string(configJson),
TimeAuditable: ossTypes.TimeAuditable{UpdatedAt: time.Now()},
}
_, err = m.DB().NewUpdate().
Model(storableDomain).
Column("data", "updated_at").
WherePK().
Exec(ctx)
if err != nil { if err != nil {
zap.L().Error("domain update failed", zap.Error(err)) zap.L().Error("domain update failed", zap.Error(err))
@ -210,9 +223,11 @@ func (m *modelDao) DeleteDomain(ctx context.Context, id uuid.UUID) basemodel.Bas
return model.InternalError(fmt.Errorf("domain delete failed")) return model.InternalError(fmt.Errorf("domain delete failed"))
} }
_, err := m.DB().ExecContext(ctx, storableDomain := &types.StorableOrgDomain{ID: id}
"DELETE FROM org_domains WHERE id = $1", _, err := m.DB().NewDelete().
id) Model(storableDomain).
WherePK().
Exec(ctx)
if err != nil { if err != nil {
zap.L().Error("domain delete failed", zap.Error(err)) zap.L().Error("domain delete failed", zap.Error(err))
@ -222,7 +237,7 @@ func (m *modelDao) DeleteDomain(ctx context.Context, id uuid.UUID) basemodel.Bas
return nil return nil
} }
func (m *modelDao) GetDomainByEmail(ctx context.Context, email string) (*model.OrgDomain, basemodel.BaseApiError) { func (m *modelDao) GetDomainByEmail(ctx context.Context, email string) (*types.GettableOrgDomain, basemodel.BaseApiError) {
if email == "" { if email == "" {
return nil, model.BadRequest(fmt.Errorf("could not find auth domain, missing fields: email ")) return nil, model.BadRequest(fmt.Errorf("could not find auth domain, missing fields: email "))
@ -235,8 +250,12 @@ func (m *modelDao) GetDomainByEmail(ctx context.Context, email string) (*model.O
parsedDomain := components[1] parsedDomain := components[1]
stored := StoredDomain{} stored := types.StorableOrgDomain{}
err := m.DB().Get(&stored, `SELECT * FROM org_domains WHERE name=$1 LIMIT 1`, parsedDomain) err := m.DB().NewSelect().
Model(&stored).
Where("name = ?", parsedDomain).
Limit(1).
Scan(ctx)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -245,7 +264,7 @@ func (m *modelDao) GetDomainByEmail(ctx context.Context, email string) (*model.O
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
domain := &model.OrgDomain{Id: stored.Id, Name: stored.Name, OrgId: stored.OrgId} domain := &types.GettableOrgDomain{StorableOrgDomain: stored}
if err := domain.LoadConfig(stored.Data); err != nil { if err := domain.LoadConfig(stored.Data); err != nil {
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }

View File

@ -3,7 +3,7 @@ package sqlite
import ( import (
"fmt" "fmt"
"github.com/jmoiron/sqlx" "github.com/uptrace/bun"
basedao "go.signoz.io/signoz/pkg/query-service/dao" basedao "go.signoz.io/signoz/pkg/query-service/dao"
basedsql "go.signoz.io/signoz/pkg/query-service/dao/sqlite" basedsql "go.signoz.io/signoz/pkg/query-service/dao/sqlite"
baseint "go.signoz.io/signoz/pkg/query-service/interfaces" baseint "go.signoz.io/signoz/pkg/query-service/interfaces"
@ -41,6 +41,6 @@ func InitDB(sqlStore sqlstore.SQLStore) (*modelDao, error) {
return m, nil return m, nil
} }
func (m *modelDao) DB() *sqlx.DB { func (m *modelDao) DB() *bun.DB {
return m.ModelDaoSqlite.DB() return m.ModelDaoSqlite.DB()
} }

View File

@ -3,7 +3,6 @@ package sqlite
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"time" "time"
"go.signoz.io/signoz/ee/query-service/model" "go.signoz.io/signoz/ee/query-service/model"
@ -12,30 +11,17 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
func (m *modelDao) CreatePAT(ctx context.Context, p model.PAT) (model.PAT, basemodel.BaseApiError) { func (m *modelDao) CreatePAT(ctx context.Context, orgID string, p model.PAT) (model.PAT, basemodel.BaseApiError) {
result, err := m.DB().ExecContext(ctx, p.StorablePersonalAccessToken.OrgID = orgID
"INSERT INTO personal_access_tokens (user_id, token, role, name, created_at, expires_at, updated_at, updated_by_user_id, last_used, revoked) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", _, err := m.DB().NewInsert().
p.UserID, Model(&p.StorablePersonalAccessToken).
p.Token, Returning("id").
p.Role, Exec(ctx)
p.Name,
p.CreatedAt,
p.ExpiresAt,
p.UpdatedAt,
p.UpdatedByUserID,
p.LastUsed,
p.Revoked,
)
if err != nil { if err != nil {
zap.L().Error("Failed to insert PAT in db, err: %v", zap.Error(err)) zap.L().Error("Failed to insert PAT in db, err: %v", zap.Error(err))
return model.PAT{}, model.InternalError(fmt.Errorf("PAT insertion failed")) return model.PAT{}, model.InternalError(fmt.Errorf("PAT insertion failed"))
} }
id, err := result.LastInsertId()
if err != nil {
zap.L().Error("Failed to get last inserted id, err: %v", zap.Error(err))
return model.PAT{}, model.InternalError(fmt.Errorf("PAT insertion failed"))
}
p.Id = strconv.Itoa(int(id))
createdByUser, _ := m.GetUser(ctx, p.UserID) createdByUser, _ := m.GetUser(ctx, p.UserID)
if createdByUser == nil { if createdByUser == nil {
p.CreatedByUser = model.User{ p.CreatedByUser = model.User{
@ -54,14 +40,14 @@ func (m *modelDao) CreatePAT(ctx context.Context, p model.PAT) (model.PAT, basem
return p, nil return p, nil
} }
func (m *modelDao) UpdatePAT(ctx context.Context, p model.PAT, id string) basemodel.BaseApiError { func (m *modelDao) UpdatePAT(ctx context.Context, orgID string, p model.PAT, id string) basemodel.BaseApiError {
_, err := m.DB().ExecContext(ctx, _, err := m.DB().NewUpdate().
"UPDATE personal_access_tokens SET role=$1, name=$2, updated_at=$3, updated_by_user_id=$4 WHERE id=$5 and revoked=false;", Model(&p.StorablePersonalAccessToken).
p.Role, Column("role", "name", "updated_at", "updated_by_user_id").
p.Name, Where("id = ?", id).
p.UpdatedAt, Where("org_id = ?", orgID).
p.UpdatedByUserID, Where("revoked = false").
id) Exec(ctx)
if err != nil { if err != nil {
zap.L().Error("Failed to update PAT in db, err: %v", zap.Error(err)) zap.L().Error("Failed to update PAT in db, err: %v", zap.Error(err))
return model.InternalError(fmt.Errorf("PAT update failed")) return model.InternalError(fmt.Errorf("PAT update failed"))
@ -69,33 +55,32 @@ func (m *modelDao) UpdatePAT(ctx context.Context, p model.PAT, id string) basemo
return nil return nil
} }
func (m *modelDao) UpdatePATLastUsed(ctx context.Context, token string, lastUsed int64) basemodel.BaseApiError { func (m *modelDao) ListPATs(ctx context.Context, orgID string) ([]model.PAT, basemodel.BaseApiError) {
_, err := m.DB().ExecContext(ctx, pats := []types.StorablePersonalAccessToken{}
"UPDATE personal_access_tokens SET last_used=$1 WHERE token=$2 and revoked=false;",
lastUsed,
token)
if err != nil {
zap.L().Error("Failed to update PAT last used in db, err: %v", zap.Error(err))
return model.InternalError(fmt.Errorf("PAT last used update failed"))
}
return nil
}
func (m *modelDao) ListPATs(ctx context.Context) ([]model.PAT, basemodel.BaseApiError) { if err := m.DB().NewSelect().
pats := []model.PAT{} Model(&pats).
Where("revoked = false").
if err := m.DB().Select(&pats, "SELECT * FROM personal_access_tokens WHERE revoked=false ORDER by updated_at DESC;"); err != nil { Where("org_id = ?", orgID).
Order("updated_at DESC").
Scan(ctx); err != nil {
zap.L().Error("Failed to fetch PATs err: %v", zap.Error(err)) zap.L().Error("Failed to fetch PATs err: %v", zap.Error(err))
return nil, model.InternalError(fmt.Errorf("failed to fetch PATs")) return nil, model.InternalError(fmt.Errorf("failed to fetch PATs"))
} }
patsWithUsers := []model.PAT{}
for i := range pats { for i := range pats {
patWithUser := model.PAT{
StorablePersonalAccessToken: pats[i],
}
createdByUser, _ := m.GetUser(ctx, pats[i].UserID) createdByUser, _ := m.GetUser(ctx, pats[i].UserID)
if createdByUser == nil { if createdByUser == nil {
pats[i].CreatedByUser = model.User{ patWithUser.CreatedByUser = model.User{
NotFound: true, NotFound: true,
} }
} else { } else {
pats[i].CreatedByUser = model.User{ patWithUser.CreatedByUser = model.User{
Id: createdByUser.ID, Id: createdByUser.ID,
Name: createdByUser.Name, Name: createdByUser.Name,
Email: createdByUser.Email, Email: createdByUser.Email,
@ -107,11 +92,11 @@ func (m *modelDao) ListPATs(ctx context.Context) ([]model.PAT, basemodel.BaseApi
updatedByUser, _ := m.GetUser(ctx, pats[i].UpdatedByUserID) updatedByUser, _ := m.GetUser(ctx, pats[i].UpdatedByUserID)
if updatedByUser == nil { if updatedByUser == nil {
pats[i].UpdatedByUser = model.User{ patWithUser.UpdatedByUser = model.User{
NotFound: true, NotFound: true,
} }
} else { } else {
pats[i].UpdatedByUser = model.User{ patWithUser.UpdatedByUser = model.User{
Id: updatedByUser.ID, Id: updatedByUser.ID,
Name: updatedByUser.Name, Name: updatedByUser.Name,
Email: updatedByUser.Email, Email: updatedByUser.Email,
@ -120,15 +105,22 @@ func (m *modelDao) ListPATs(ctx context.Context) ([]model.PAT, basemodel.BaseApi
NotFound: false, NotFound: false,
} }
} }
patsWithUsers = append(patsWithUsers, patWithUser)
} }
return pats, nil return patsWithUsers, nil
} }
func (m *modelDao) RevokePAT(ctx context.Context, id string, userID string) basemodel.BaseApiError { func (m *modelDao) RevokePAT(ctx context.Context, orgID string, id string, userID string) basemodel.BaseApiError {
updatedAt := time.Now().Unix() updatedAt := time.Now().Unix()
_, err := m.DB().ExecContext(ctx, _, err := m.DB().NewUpdate().
"UPDATE personal_access_tokens SET revoked=true, updated_by_user_id = $1, updated_at=$2 WHERE id=$3", Model(&types.StorablePersonalAccessToken{}).
userID, updatedAt, id) Set("revoked = ?", true).
Set("updated_by_user_id = ?", userID).
Set("updated_at = ?", updatedAt).
Where("id = ?", id).
Where("org_id = ?", orgID).
Exec(ctx)
if err != nil { if err != nil {
zap.L().Error("Failed to revoke PAT in db, err: %v", zap.Error(err)) zap.L().Error("Failed to revoke PAT in db, err: %v", zap.Error(err))
return model.InternalError(fmt.Errorf("PAT revoke failed")) return model.InternalError(fmt.Errorf("PAT revoke failed"))
@ -137,9 +129,13 @@ func (m *modelDao) RevokePAT(ctx context.Context, id string, userID string) base
} }
func (m *modelDao) GetPAT(ctx context.Context, token string) (*model.PAT, basemodel.BaseApiError) { func (m *modelDao) GetPAT(ctx context.Context, token string) (*model.PAT, basemodel.BaseApiError) {
pats := []model.PAT{} pats := []types.StorablePersonalAccessToken{}
if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE token=? and revoked=false;`, token); err != nil { if err := m.DB().NewSelect().
Model(&pats).
Where("token = ?", token).
Where("revoked = false").
Scan(ctx); err != nil {
return nil, model.InternalError(fmt.Errorf("failed to fetch PAT")) return nil, model.InternalError(fmt.Errorf("failed to fetch PAT"))
} }
@ -150,13 +146,22 @@ func (m *modelDao) GetPAT(ctx context.Context, token string) (*model.PAT, basemo
} }
} }
return &pats[0], nil patWithUser := model.PAT{
StorablePersonalAccessToken: pats[0],
}
return &patWithUser, nil
} }
func (m *modelDao) GetPATByID(ctx context.Context, id string) (*model.PAT, basemodel.BaseApiError) { func (m *modelDao) GetPATByID(ctx context.Context, orgID string, id string) (*model.PAT, basemodel.BaseApiError) {
pats := []model.PAT{} pats := []types.StorablePersonalAccessToken{}
if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE id=? and revoked=false;`, id); err != nil { if err := m.DB().NewSelect().
Model(&pats).
Where("id = ?", id).
Where("org_id = ?", orgID).
Where("revoked = false").
Scan(ctx); err != nil {
return nil, model.InternalError(fmt.Errorf("failed to fetch PAT")) return nil, model.InternalError(fmt.Errorf("failed to fetch PAT"))
} }
@ -167,26 +172,25 @@ func (m *modelDao) GetPATByID(ctx context.Context, id string) (*model.PAT, basem
} }
} }
return &pats[0], nil patWithUser := model.PAT{
StorablePersonalAccessToken: pats[0],
}
return &patWithUser, nil
} }
// deprecated // deprecated
func (m *modelDao) GetUserByPAT(ctx context.Context, token string) (*types.GettableUser, basemodel.BaseApiError) { func (m *modelDao) GetUserByPAT(ctx context.Context, orgID string, token string) (*types.GettableUser, basemodel.BaseApiError) {
users := []types.GettableUser{} users := []types.GettableUser{}
query := `SELECT if err := m.DB().NewSelect().
u.id, Model(&users).
u.name, Column("u.id", "u.name", "u.email", "u.password", "u.created_at", "u.profile_picture_url", "u.org_id", "u.group_id").
u.email, Join("JOIN personal_access_tokens p ON u.id = p.user_id").
u.password, Where("p.token = ?", token).
u.created_at, Where("p.expires_at >= strftime('%s', 'now')").
u.profile_picture_url, Where("p.org_id = ?", orgID).
u.org_id, Scan(ctx); err != nil {
u.group_id
FROM users u, personal_access_tokens p
WHERE u.id = p.user_id and p.token=? and p.expires_at >= strftime('%s', 'now');`
if err := m.DB().Select(&users, query, token); err != nil {
return nil, model.InternalError(fmt.Errorf("failed to fetch user from PAT, err: %v", err)) return nil, model.InternalError(fmt.Errorf("failed to fetch user from PAT, err: %v", err))
} }

View File

@ -1,5 +1,7 @@
package model package model
import "go.signoz.io/signoz/pkg/types"
type User struct { type User struct {
Id string `json:"id" db:"id"` Id string `json:"id" db:"id"`
Name string `json:"name" db:"name"` Name string `json:"name" db:"name"`
@ -16,17 +18,8 @@ type CreatePATRequestBody struct {
} }
type PAT struct { type PAT struct {
Id string `json:"id" db:"id"`
UserID string `json:"userId" db:"user_id"`
CreatedByUser User `json:"createdByUser"` CreatedByUser User `json:"createdByUser"`
UpdatedByUser User `json:"updatedByUser"` UpdatedByUser User `json:"updatedByUser"`
Token string `json:"token" db:"token"`
Role string `json:"role" db:"role"` types.StorablePersonalAccessToken
Name string `json:"name" db:"name"`
CreatedAt int64 `json:"createdAt" db:"created_at"`
ExpiresAt int64 `json:"expiresAt" db:"expires_at"`
UpdatedAt int64 `json:"updatedAt" db:"updated_at"`
LastUsed int64 `json:"lastUsed" db:"last_used"`
Revoked bool `json:"revoked" db:"revoked"`
UpdatedByUserID string `json:"updatedByUserId" db:"updated_by_user_id"`
} }

View File

@ -1,4 +1,4 @@
package model package types
import ( import (
"encoding/json" "encoding/json"
@ -9,12 +9,23 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
saml2 "github.com/russellhaering/gosaml2" saml2 "github.com/russellhaering/gosaml2"
"github.com/uptrace/bun"
"go.signoz.io/signoz/ee/query-service/sso" "go.signoz.io/signoz/ee/query-service/sso"
"go.signoz.io/signoz/ee/query-service/sso/saml" "go.signoz.io/signoz/ee/query-service/sso/saml"
"go.signoz.io/signoz/pkg/types" "go.signoz.io/signoz/pkg/types"
"go.uber.org/zap" "go.uber.org/zap"
) )
type StorableOrgDomain struct {
bun.BaseModel `bun:"table:org_domains"`
types.TimeAuditable
ID uuid.UUID `json:"id" bun:"id,pk,type:text"`
OrgID string `json:"orgId" bun:"org_id,type:text,notnull"`
Name string `json:"name" bun:"name,type:varchar(50),notnull,unique"`
Data string `json:"-" bun:"data,type:text,notnull"`
}
type SSOType string type SSOType string
const ( const (
@ -22,11 +33,10 @@ const (
GoogleAuth SSOType = "GOOGLE_AUTH" GoogleAuth SSOType = "GOOGLE_AUTH"
) )
// OrgDomain identify org owned web domains for auth and other purposes // GettableOrgDomain identify org owned web domains for auth and other purposes
type OrgDomain struct { type GettableOrgDomain struct {
Id uuid.UUID `json:"id"` StorableOrgDomain
Name string `json:"name"`
OrgId string `json:"orgId"`
SsoEnabled bool `json:"ssoEnabled"` SsoEnabled bool `json:"ssoEnabled"`
SsoType SSOType `json:"ssoType"` SsoType SSOType `json:"ssoType"`
@ -36,18 +46,18 @@ type OrgDomain struct {
Org *types.Organization Org *types.Organization
} }
func (od *OrgDomain) String() string { func (od *GettableOrgDomain) String() string {
return fmt.Sprintf("[%s]%s-%s ", od.Name, od.Id.String(), od.SsoType) return fmt.Sprintf("[%s]%s-%s ", od.Name, od.ID.String(), od.SsoType)
} }
// Valid is used a pipeline function to check if org domain // Valid is used a pipeline function to check if org domain
// loaded from db is valid // loaded from db is valid
func (od *OrgDomain) Valid(err error) error { func (od *GettableOrgDomain) Valid(err error) error {
if err != nil { if err != nil {
return err return err
} }
if od.Id == uuid.Nil || od.OrgId == "" { if od.ID == uuid.Nil || od.OrgID == "" {
return fmt.Errorf("both id and orgId are required") return fmt.Errorf("both id and orgId are required")
} }
@ -55,9 +65,9 @@ func (od *OrgDomain) Valid(err error) error {
} }
// ValidNew cheks if the org domain is valid for insertion in db // ValidNew cheks if the org domain is valid for insertion in db
func (od *OrgDomain) ValidNew() error { func (od *GettableOrgDomain) ValidNew() error {
if od.OrgId == "" { if od.OrgID == "" {
return fmt.Errorf("orgId is required") return fmt.Errorf("orgId is required")
} }
@ -69,7 +79,7 @@ func (od *OrgDomain) ValidNew() error {
} }
// LoadConfig loads config params from json text // LoadConfig loads config params from json text
func (od *OrgDomain) LoadConfig(jsondata string) error { func (od *GettableOrgDomain) LoadConfig(jsondata string) error {
d := *od d := *od
err := json.Unmarshal([]byte(jsondata), &d) err := json.Unmarshal([]byte(jsondata), &d)
if err != nil { if err != nil {
@ -79,21 +89,21 @@ func (od *OrgDomain) LoadConfig(jsondata string) error {
return nil return nil
} }
func (od *OrgDomain) GetSAMLEntityID() string { func (od *GettableOrgDomain) GetSAMLEntityID() string {
if od.SamlConfig != nil { if od.SamlConfig != nil {
return od.SamlConfig.SamlEntity return od.SamlConfig.SamlEntity
} }
return "" return ""
} }
func (od *OrgDomain) GetSAMLIdpURL() string { func (od *GettableOrgDomain) GetSAMLIdpURL() string {
if od.SamlConfig != nil { if od.SamlConfig != nil {
return od.SamlConfig.SamlIdp return od.SamlConfig.SamlIdp
} }
return "" return ""
} }
func (od *OrgDomain) GetSAMLCert() string { func (od *GettableOrgDomain) GetSAMLCert() string {
if od.SamlConfig != nil { if od.SamlConfig != nil {
return od.SamlConfig.SamlCert return od.SamlConfig.SamlCert
} }
@ -102,7 +112,7 @@ func (od *OrgDomain) GetSAMLCert() string {
// PrepareGoogleOAuthProvider creates GoogleProvider that is used in // PrepareGoogleOAuthProvider creates GoogleProvider that is used in
// requesting OAuth and also used in processing response from google // requesting OAuth and also used in processing response from google
func (od *OrgDomain) PrepareGoogleOAuthProvider(siteUrl *url.URL) (sso.OAuthCallbackProvider, error) { func (od *GettableOrgDomain) PrepareGoogleOAuthProvider(siteUrl *url.URL) (sso.OAuthCallbackProvider, error) {
if od.GoogleAuthConfig == nil { if od.GoogleAuthConfig == nil {
return nil, fmt.Errorf("GOOGLE OAUTH is not setup correctly for this domain") return nil, fmt.Errorf("GOOGLE OAUTH is not setup correctly for this domain")
} }
@ -111,7 +121,7 @@ func (od *OrgDomain) PrepareGoogleOAuthProvider(siteUrl *url.URL) (sso.OAuthCall
} }
// PrepareSamlRequest creates a request accordingly gosaml2 // PrepareSamlRequest creates a request accordingly gosaml2
func (od *OrgDomain) PrepareSamlRequest(siteUrl *url.URL) (*saml2.SAMLServiceProvider, error) { func (od *GettableOrgDomain) PrepareSamlRequest(siteUrl *url.URL) (*saml2.SAMLServiceProvider, error) {
// this is the url Idp will call after login completion // this is the url Idp will call after login completion
acs := fmt.Sprintf("%s://%s/%s", acs := fmt.Sprintf("%s://%s/%s",
@ -136,9 +146,9 @@ func (od *OrgDomain) PrepareSamlRequest(siteUrl *url.URL) (*saml2.SAMLServicePro
return saml.PrepareRequest(issuer, acs, sourceUrl, od.GetSAMLEntityID(), od.GetSAMLIdpURL(), od.GetSAMLCert()) return saml.PrepareRequest(issuer, acs, sourceUrl, od.GetSAMLEntityID(), od.GetSAMLIdpURL(), od.GetSAMLCert())
} }
func (od *OrgDomain) BuildSsoUrl(siteUrl *url.URL) (ssoUrl string, err error) { func (od *GettableOrgDomain) BuildSsoUrl(siteUrl *url.URL) (ssoUrl string, err error) {
fmtDomainId := strings.Replace(od.Id.String(), "-", ":", -1) fmtDomainId := strings.Replace(od.ID.String(), "-", ":", -1)
// build redirect url from window.location sent by frontend // build redirect url from window.location sent by frontend
redirectURL := fmt.Sprintf("%s://%s%s", siteUrl.Scheme, siteUrl.Host, siteUrl.Path) redirectURL := fmt.Sprintf("%s://%s%s", siteUrl.Scheme, siteUrl.Host, siteUrl.Path)

View File

@ -1,16 +1,15 @@
package model package types
import ( import (
"fmt"
"context" "context"
"fmt"
"net/url" "net/url"
"golang.org/x/oauth2"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"go.signoz.io/signoz/ee/query-service/sso" "go.signoz.io/signoz/ee/query-service/sso"
"golang.org/x/oauth2"
) )
// SamlConfig contans SAML params to generate and respond to the requests
// from SAML provider
type SamlConfig struct { type SamlConfig struct {
SamlEntity string `json:"samlEntity"` SamlEntity string `json:"samlEntity"`
SamlIdp string `json:"samlIdp"` SamlIdp string `json:"samlIdp"`
@ -24,7 +23,6 @@ type GoogleOAuthConfig struct {
RedirectURI string `json:"redirectURI"` RedirectURI string `json:"redirectURI"`
} }
const ( const (
googleIssuerURL = "https://accounts.google.com" googleIssuerURL = "https://accounts.google.com"
) )
@ -65,4 +63,3 @@ func (g *GoogleOAuthConfig) GetProvider(domain string, siteUrl *url.URL) (sso.OA
HostedDomain: domain, HostedDomain: domain,
}, nil }, nil
} }

View File

@ -339,12 +339,15 @@ function APIKeys(): JSX.Element {
? getFormattedTime(APIKey?.lastUsed) ? getFormattedTime(APIKey?.lastUsed)
: 'Never'; : 'Never';
const createdOn = getFormattedTime(APIKey.createdAt); const createdOn = new Date(APIKey.createdAt).toLocaleString();
const expiresIn = const expiresIn =
APIKey.expiresAt === 0 APIKey.expiresAt === 0
? Number.POSITIVE_INFINITY ? Number.POSITIVE_INFINITY
: getDateDifference(APIKey?.createdAt, APIKey?.expiresAt); : getDateDifference(
new Date(APIKey?.createdAt).getTime() / 1000,
APIKey?.expiresAt,
);
const isExpired = isExpiredToken(APIKey.expiresAt); const isExpired = isExpiredToken(APIKey.expiresAt);
@ -354,9 +357,9 @@ function APIKeys(): JSX.Element {
: getFormattedTime(APIKey.expiresAt); : getFormattedTime(APIKey.expiresAt);
const updatedOn = const updatedOn =
!APIKey.updatedAt || APIKey.updatedAt === 0 !APIKey.updatedAt || APIKey.updatedAt === ''
? null ? null
: getFormattedTime(APIKey?.updatedAt); : new Date(APIKey.updatedAt).toLocaleString();
const items: CollapseProps['items'] = [ const items: CollapseProps['items'] = [
{ {
@ -835,7 +838,9 @@ function APIKeys(): JSX.Element {
{activeAPIKey?.createdAt && ( {activeAPIKey?.createdAt && (
<Row> <Row>
<Col span={8}>Created on</Col> <Col span={8}>Created on</Col>
<Col span={16}>{getFormattedTime(activeAPIKey?.createdAt)}</Col> <Col span={16}>
{new Date(activeAPIKey?.createdAt).toLocaleString()}
</Col>
</Row> </Row>
)} )}

View File

@ -13,9 +13,9 @@ export interface APIKeyProps {
role: string; role: string;
token: string; token: string;
id: string; id: string;
createdAt: number; createdAt: string;
createdByUser?: User; createdByUser?: User;
updatedAt?: number; updatedAt?: string;
updatedByUser?: User; updatedByUser?: User;
lastUsed?: number; lastUsed?: number;
} }

View File

@ -551,8 +551,6 @@ func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *AuthMiddleware) {
router.HandleFunc("/api/v1/settings/ttl", am.ViewAccess(aH.getTTL)).Methods(http.MethodGet) router.HandleFunc("/api/v1/settings/ttl", am.ViewAccess(aH.getTTL)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/settings/apdex", am.AdminAccess(aH.setApdexSettings)).Methods(http.MethodPost) router.HandleFunc("/api/v1/settings/apdex", am.AdminAccess(aH.setApdexSettings)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/settings/apdex", am.ViewAccess(aH.getApdexSettings)).Methods(http.MethodGet) router.HandleFunc("/api/v1/settings/apdex", am.ViewAccess(aH.getApdexSettings)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/settings/ingestion_key", am.AdminAccess(aH.insertIngestionKey)).Methods(http.MethodPost)
router.HandleFunc("/api/v1/settings/ingestion_key", am.ViewAccess(aH.getIngestionKeys)).Methods(http.MethodGet)
router.HandleFunc("/api/v2/traces/fields", am.ViewAccess(aH.traceFields)).Methods(http.MethodGet) router.HandleFunc("/api/v2/traces/fields", am.ViewAccess(aH.traceFields)).Methods(http.MethodGet)
router.HandleFunc("/api/v2/traces/fields", am.EditAccess(aH.updateTraceField)).Methods(http.MethodPost) router.HandleFunc("/api/v2/traces/fields", am.EditAccess(aH.updateTraceField)).Methods(http.MethodPost)

View File

@ -1,33 +0,0 @@
package app
import (
"context"
"net/http"
"go.signoz.io/signoz/pkg/query-service/dao"
"go.signoz.io/signoz/pkg/query-service/model"
)
func (aH *APIHandler) insertIngestionKey(w http.ResponseWriter, r *http.Request) {
req, err := parseInsertIngestionKeyRequest(r)
if aH.HandleError(w, err, http.StatusBadRequest) {
return
}
if err := dao.DB().InsertIngestionKey(context.Background(), req); err != nil {
RespondError(w, &model.ApiError{Err: err, Typ: model.ErrorInternal}, nil)
return
}
aH.WriteJSON(w, r, map[string]string{"data": "ingestion key added successfully"})
}
func (aH *APIHandler) getIngestionKeys(w http.ResponseWriter, r *http.Request) {
ingestionKeys, err := dao.DB().GetIngestionKeys(context.Background())
if err != nil {
RespondError(w, &model.ApiError{Err: err, Typ: model.ErrorInternal}, nil)
return
}
aH.WriteJSON(w, r, ingestionKeys)
}

View File

@ -36,8 +36,6 @@ type Queries interface {
GetApdexSettings(ctx context.Context, orgID string, services []string) ([]types.ApdexSettings, *model.ApiError) GetApdexSettings(ctx context.Context, orgID string, services []string) ([]types.ApdexSettings, *model.ApiError)
GetIngestionKeys(ctx context.Context) ([]model.IngestionKey, *model.ApiError)
PrecheckLogin(ctx context.Context, email, sourceUrl string) (*model.PrecheckResponse, model.BaseApiError) PrecheckLogin(ctx context.Context, email, sourceUrl string) (*model.PrecheckResponse, model.BaseApiError)
} }
@ -63,6 +61,4 @@ type Mutations interface {
UpdateUserGroup(ctx context.Context, userId, groupId string) *model.ApiError UpdateUserGroup(ctx context.Context, userId, groupId string) *model.ApiError
SetApdexSettings(ctx context.Context, orgID string, set *types.ApdexSettings) *model.ApiError SetApdexSettings(ctx context.Context, orgID string, set *types.ApdexSettings) *model.ApiError
InsertIngestionKey(ctx context.Context, ingestionKey *model.IngestionKey) *model.ApiError
} }

View File

@ -3,7 +3,6 @@ package sqlite
import ( import (
"context" "context"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"go.signoz.io/signoz/pkg/query-service/constants" "go.signoz.io/signoz/pkg/query-service/constants"
@ -14,13 +13,12 @@ import (
) )
type ModelDaoSqlite struct { type ModelDaoSqlite struct {
db *sqlx.DB
bundb *bun.DB bundb *bun.DB
} }
// InitDB sets up setting up the connection pool global variable. // InitDB sets up setting up the connection pool global variable.
func InitDB(sqlStore sqlstore.SQLStore) (*ModelDaoSqlite, error) { func InitDB(sqlStore sqlstore.SQLStore) (*ModelDaoSqlite, error) {
mds := &ModelDaoSqlite{db: sqlStore.SQLxDB(), bundb: sqlStore.BunDB()} mds := &ModelDaoSqlite{bundb: sqlStore.BunDB()}
ctx := context.Background() ctx := context.Background()
if err := mds.initializeOrgPreferences(ctx); err != nil { if err := mds.initializeOrgPreferences(ctx); err != nil {
@ -38,8 +36,8 @@ func InitDB(sqlStore sqlstore.SQLStore) (*ModelDaoSqlite, error) {
} }
// DB returns database connection // DB returns database connection
func (mds *ModelDaoSqlite) DB() *sqlx.DB { func (mds *ModelDaoSqlite) DB() *bun.DB {
return mds.db return mds.bundb
} }
// initializeOrgPreferences initializes in-memory telemetry settings. It is planned to have // initializeOrgPreferences initializes in-memory telemetry settings. It is planned to have

View File

@ -1,39 +0,0 @@
package sqlite
import (
"context"
"go.signoz.io/signoz/pkg/query-service/model"
)
func (mds *ModelDaoSqlite) GetIngestionKeys(ctx context.Context) ([]model.IngestionKey, *model.ApiError) {
ingestion_keys := []model.IngestionKey{}
err := mds.db.Select(&ingestion_keys, `SELECT * FROM ingestion_keys`)
if err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
return ingestion_keys, nil
}
func (mds *ModelDaoSqlite) InsertIngestionKey(ctx context.Context, ingestion_key *model.IngestionKey) *model.ApiError {
_, err := mds.db.ExecContext(ctx, `
INSERT INTO ingestion_keys (
ingestion_key,
name,
key_id,
ingestion_url,
data_region
) VALUES (
?,
?,
?,
?,
?
)`, ingestion_key.IngestionKey, ingestion_key.Name, ingestion_key.KeyId, ingestion_key.IngestionURL, ingestion_key.DataRegion)
if err != nil {
return &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
return nil
}

View File

@ -49,6 +49,7 @@ func NewTestSqliteDB(t *testing.T) (sqlStore sqlstore.SQLStore, testDBFilePath s
sqlmigration.NewModifyOrgDomainFactory(), sqlmigration.NewModifyOrgDomainFactory(),
sqlmigration.NewUpdateOrganizationFactory(sqlStore), sqlmigration.NewUpdateOrganizationFactory(sqlStore),
sqlmigration.NewUpdateDashboardAndSavedViewsFactory(sqlStore), sqlmigration.NewUpdateDashboardAndSavedViewsFactory(sqlStore),
sqlmigration.NewUpdatePatAndOrgDomainsFactory(sqlStore),
), ),
) )
if err != nil { if err != nil {

View File

@ -58,8 +58,9 @@ func NewSQLMigrationProviderFactories(sqlstore sqlstore.SQLStore) factory.NamedM
sqlmigration.NewModifyDatetimeFactory(), sqlmigration.NewModifyDatetimeFactory(),
sqlmigration.NewModifyOrgDomainFactory(), sqlmigration.NewModifyOrgDomainFactory(),
sqlmigration.NewUpdateOrganizationFactory(sqlstore), sqlmigration.NewUpdateOrganizationFactory(sqlstore),
sqlmigration.NewAddAlertmanagerFactory(), sqlmigration.NewAddAlertmanagerFactory(sqlstore),
sqlmigration.NewUpdateDashboardAndSavedViewsFactory(sqlstore), sqlmigration.NewUpdateDashboardAndSavedViewsFactory(sqlstore),
sqlmigration.NewUpdatePatAndOrgDomainsFactory(sqlstore),
) )
} }

View File

@ -14,17 +14,24 @@ import (
"github.com/uptrace/bun/migrate" "github.com/uptrace/bun/migrate"
"go.signoz.io/signoz/pkg/alertmanager/alertmanagerserver" "go.signoz.io/signoz/pkg/alertmanager/alertmanagerserver"
"go.signoz.io/signoz/pkg/factory" "go.signoz.io/signoz/pkg/factory"
"go.signoz.io/signoz/pkg/sqlstore"
"go.signoz.io/signoz/pkg/types/alertmanagertypes" "go.signoz.io/signoz/pkg/types/alertmanagertypes"
) )
type addAlertmanager struct{} type addAlertmanager struct {
store sqlstore.SQLStore
func NewAddAlertmanagerFactory() factory.ProviderFactory[SQLMigration, Config] {
return factory.NewProviderFactory(factory.MustNewName("add_alertmanager"), newAddAlertmanager)
} }
func newAddAlertmanager(_ context.Context, _ factory.ProviderSettings, _ Config) (SQLMigration, error) { func NewAddAlertmanagerFactory(store sqlstore.SQLStore) factory.ProviderFactory[SQLMigration, Config] {
return &addAlertmanager{}, nil return factory.NewProviderFactory(factory.MustNewName("add_alertmanager"), func(ctx context.Context, ps factory.ProviderSettings, c Config) (SQLMigration, error) {
return newAddAlertmanager(ctx, ps, c, store)
})
}
func newAddAlertmanager(_ context.Context, _ factory.ProviderSettings, _ Config, store sqlstore.SQLStore) (SQLMigration, error) {
return &addAlertmanager{
store: store,
}, nil
} }
func (migration *addAlertmanager) Register(migrations *migrate.Migrations) error { func (migration *addAlertmanager) Register(migrations *migrate.Migrations) error {
@ -53,14 +60,17 @@ func (migration *addAlertmanager) Up(ctx context.Context, db *bun.DB) error {
} }
} }
if exists, err := migration.store.Dialect().ColumnExists(ctx, tx, "notification_channels", "org_id"); err != nil {
return err
} else if !exists {
if _, err := tx. if _, err := tx.
NewAddColumn(). NewAddColumn().
Table("notification_channels"). Table("notification_channels").
ColumnExpr("org_id"). ColumnExpr("org_id TEXT REFERENCES organizations(id) ON DELETE CASCADE").
Apply(WrapIfNotExists(ctx, db, "notification_channels", "org_id")).
Exec(ctx); err != nil && err != ErrNoExecute { Exec(ctx); err != nil && err != ErrNoExecute {
return err return err
} }
}
if _, err := tx. if _, err := tx.
NewCreateTable(). NewCreateTable().

View File

@ -0,0 +1,131 @@
package sqlmigration
import (
"context"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate"
"go.signoz.io/signoz/pkg/factory"
"go.signoz.io/signoz/pkg/sqlstore"
"go.signoz.io/signoz/pkg/types"
)
type updatePatAndOrgDomains struct {
store sqlstore.SQLStore
}
func NewUpdatePatAndOrgDomainsFactory(sqlstore sqlstore.SQLStore) factory.ProviderFactory[SQLMigration, Config] {
return factory.NewProviderFactory(factory.MustNewName("update_pat_and_org_domains"), func(ctx context.Context, ps factory.ProviderSettings, c Config) (SQLMigration, error) {
return newUpdatePatAndOrgDomains(ctx, ps, c, sqlstore)
})
}
func newUpdatePatAndOrgDomains(_ context.Context, _ factory.ProviderSettings, _ Config, store sqlstore.SQLStore) (SQLMigration, error) {
return &updatePatAndOrgDomains{
store: store,
}, nil
}
func (migration *updatePatAndOrgDomains) Register(migrations *migrate.Migrations) error {
if err := migrations.Register(migration.Up, migration.Down); err != nil {
return err
}
return nil
}
func (migration *updatePatAndOrgDomains) Up(ctx context.Context, db *bun.DB) error {
// begin transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
// get all org ids
var orgIDs []string
if err := tx.NewSelect().Model((*types.Organization)(nil)).Column("id").Scan(ctx, &orgIDs); err != nil {
return err
}
// add org id to pat and org_domains table
if exists, err := migration.store.Dialect().ColumnExists(ctx, tx, "personal_access_tokens", "org_id"); err != nil {
return err
} else if !exists {
if _, err := tx.NewAddColumn().Table("personal_access_tokens").ColumnExpr("org_id TEXT REFERENCES organizations(id) ON DELETE CASCADE").Exec(ctx); err != nil {
return err
}
// check if there is one org ID if yes then set it to all personal_access_tokens.
if len(orgIDs) == 1 {
orgID := orgIDs[0]
if _, err := tx.NewUpdate().Table("personal_access_tokens").Set("org_id = ?", orgID).Where("org_id IS NULL").Exec(ctx); err != nil {
return err
}
}
}
if err := updateOrgId(ctx, tx, "org_domains"); err != nil {
return err
}
// change created_at and updated_at from integer to timestamp
for _, table := range []string{"personal_access_tokens", "org_domains"} {
if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "created_at"); err != nil {
return err
}
if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "updated_at"); err != nil {
return err
}
}
// drop table if exists ingestion_keys
if _, err := tx.NewDropTable().IfExists().Table("ingestion_keys").Exec(ctx); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (migration *updatePatAndOrgDomains) Down(ctx context.Context, db *bun.DB) error {
return nil
}
func updateOrgId(ctx context.Context, tx bun.Tx, table string) error {
if _, err := tx.NewCreateTable().
Model(&struct {
bun.BaseModel `bun:"table:org_domains_new"`
ID string `bun:"id,pk,type:text"`
OrgID string `bun:"org_id,type:text,notnull"`
Name string `bun:"name,type:varchar(50),notnull,unique"`
CreatedAt int `bun:"created_at,notnull"`
UpdatedAt int `bun:"updated_at"`
Data string `bun:"data,type:text,notnull"`
}{}).
ForeignKey(`("org_id") REFERENCES "organizations" ("id") ON DELETE CASCADE`).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// copy data from org_domains to org_domains_new
if _, err := tx.ExecContext(ctx, `INSERT INTO org_domains_new (id, org_id, name, created_at, updated_at, data) SELECT id, org_id, name, created_at, updated_at, data FROM org_domains`); err != nil {
return err
}
// delete old table
if _, err := tx.NewDropTable().IfExists().Table("org_domains").Exec(ctx); err != nil {
return err
}
// rename new table to org_domains
if _, err := tx.ExecContext(ctx, `ALTER TABLE org_domains_new RENAME TO org_domains`); err != nil {
return err
}
return nil
}

View File

@ -2,7 +2,6 @@ package sqlmigration
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@ -62,31 +61,6 @@ func MustNew(
return migrations return migrations
} }
func WrapIfNotExists(ctx context.Context, db *bun.DB, table string, column string) func(q *bun.AddColumnQuery) *bun.AddColumnQuery {
return func(q *bun.AddColumnQuery) *bun.AddColumnQuery {
if db.Dialect().Name() != dialect.SQLite {
return q.IfNotExists()
}
var result string
err := db.
NewSelect().
ColumnExpr("name").
Table("pragma_table_info").
Where("arg = ?", table).
Where("name = ?", column).
Scan(ctx, &result)
if err != nil {
if err == sql.ErrNoRows {
return q
}
return q.Err(err)
}
return q.Err(ErrNoExecute)
}
}
func GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) { func GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) {
var columnType string var columnType string
var err error var err error

View File

@ -4,29 +4,18 @@ import (
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
type PersonalAccessToken struct { type StorablePersonalAccessToken struct {
bun.BaseModel `bun:"table:personal_access_tokens"` bun.BaseModel `bun:"table:personal_access_tokens"`
ID int `bun:"id,pk,autoincrement"` TimeAuditable
Role string `bun:"role,type:text,notnull,default:'ADMIN'"` OrgID string `json:"orgId" bun:"org_id,type:text,notnull"`
UserID string `bun:"user_id,type:text,notnull"` ID int `json:"id" bun:"id,pk,autoincrement"`
Token string `bun:"token,type:text,notnull,unique"` Role string `json:"role" bun:"role,type:text,notnull,default:'ADMIN'"`
Name string `bun:"name,type:text,notnull"` UserID string `json:"userId" bun:"user_id,type:text,notnull"`
CreatedAt int `bun:"created_at,notnull,default:0"` Token string `json:"token" bun:"token,type:text,notnull,unique"`
ExpiresAt int `bun:"expires_at,notnull,default:0"` Name string `json:"name" bun:"name,type:text,notnull"`
UpdatedAt int `bun:"updated_at,notnull,default:0"` ExpiresAt int64 `json:"expiresAt" bun:"expires_at,notnull,default:0"`
LastUsed int `bun:"last_used,notnull,default:0"` LastUsed int64 `json:"lastUsed" bun:"last_used,notnull,default:0"`
Revoked bool `bun:"revoked,notnull,default:false"` Revoked bool `json:"revoked" bun:"revoked,notnull,default:false"`
UpdatedByUserID string `bun:"updated_by_user_id,type:text,notnull,default:''"` UpdatedByUserID string `json:"updatedByUserId" bun:"updated_by_user_id,type:text,notnull,default:''"`
}
type OrgDomain struct {
bun.BaseModel `bun:"table:org_domains"`
ID string `bun:"id,pk,type:text"`
OrgID string `bun:"org_id,type:text,notnull"`
Name string `bun:"name,type:varchar(50),notnull,unique"`
CreatedAt int `bun:"created_at,notnull"`
UpdatedAt int `bun:"updated_at,type:timestamp"`
Data string `bun:"data,type:text,notnull"`
} }