diff --git a/.devenv/docker/postgres/compose.yaml b/.devenv/docker/postgres/compose.yaml new file mode 100644 index 0000000000..7b08a021ef --- /dev/null +++ b/.devenv/docker/postgres/compose.yaml @@ -0,0 +1,27 @@ +services: + + postgres: + image: postgres:15 + container_name: postgres + environment: + POSTGRES_DB: signoz + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + healthcheck: + test: + [ + "CMD", + "pg_isready", + "-d", + "signoz", + "-U", + "postgres" + ] + interval: 30s + timeout: 30s + retries: 3 + restart: on-failure + ports: + - "127.0.0.1:5432:5432/tcp" + volumes: + - ${PWD}/fs/tmp/var/lib/postgresql/data/:/var/lib/postgresql/data/ \ No newline at end of file diff --git a/.github/workflows/integrationci.yaml b/.github/workflows/integrationci.yaml index b27278c93a..77e032ba42 100644 --- a/.github/workflows/integrationci.yaml +++ b/.github/workflows/integrationci.yaml @@ -44,10 +44,8 @@ jobs: - name: run run: | cd tests/integration && \ - poetry run pytest -ra \ + poetry run pytest \ --basetemp=./tmp/ \ - -vv \ - --capture=no \ src/${{matrix.src}} \ --sqlstore-provider ${{matrix.sqlstore-provider}} \ --postgres-version ${{matrix.postgres-version}} \ diff --git a/Makefile b/Makefile index 003d7e8d64..1dc3c0a8ef 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,11 @@ devenv-clickhouse: ## Run clickhouse in devenv @cd .devenv/docker/clickhouse; \ docker compose -f compose.yaml up -d +.PHONY: devenv-postgres +devenv-postgres: ## Run postgres in devenv + @cd .devenv/docker/postgres; \ + docker compose -f compose.yaml up -d + ############################################################## # go commands ############################################################## diff --git a/ee/http/middleware/pat.go b/ee/http/middleware/pat.go index 399d64aad5..0f4ea5d755 100644 --- a/ee/http/middleware/pat.go +++ b/ee/http/middleware/pat.go @@ -61,11 +61,17 @@ func (p *Pat) Wrap(next http.Handler) http.Handler { return } + role, err := authtypes.NewRole(user.Role) + if err != nil { + next.ServeHTTP(w, r) + return + } + jwt := authtypes.Claims{ - UserID: user.ID, - GroupID: user.GroupID, - Email: user.Email, - OrgID: user.OrgID, + UserID: user.ID, + Role: role, + Email: user.Email, + OrgID: user.OrgID, } ctx = authtypes.NewContextWithClaims(ctx, jwt) diff --git a/ee/query-service/app/api/api.go b/ee/query-service/app/api/api.go index 00e552d07b..fa9f9ac797 100644 --- a/ee/query-service/app/api/api.go +++ b/ee/query-service/app/api/api.go @@ -12,6 +12,7 @@ import ( "github.com/SigNoz/signoz/ee/query-service/usage" "github.com/SigNoz/signoz/pkg/alertmanager" "github.com/SigNoz/signoz/pkg/apis/fields" + "github.com/SigNoz/signoz/pkg/http/middleware" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization" "github.com/SigNoz/signoz/pkg/modules/preference" preferencecore "github.com/SigNoz/signoz/pkg/modules/preference/core" @@ -126,7 +127,7 @@ func (ah *APIHandler) CheckFeature(f string) bool { } // RegisterRoutes registers routes for this handler on the given router -func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *baseapp.AuthMiddleware) { +func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *middleware.AuthZ) { // note: add ee override methods first // routes available only in ee version @@ -199,7 +200,7 @@ func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *baseapp.AuthMiddlew } -func (ah *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *baseapp.AuthMiddleware) { +func (ah *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *middleware.AuthZ) { ah.APIHandler.RegisterCloudIntegrationsRoutes(router, am) diff --git a/ee/query-service/app/api/cloudIntegrations.go b/ee/query-service/app/api/cloudIntegrations.go index 90714beff3..5d00741b6a 100644 --- a/ee/query-service/app/api/cloudIntegrations.go +++ b/ee/query-service/app/api/cloudIntegrations.go @@ -12,11 +12,11 @@ import ( "github.com/SigNoz/signoz/ee/query-service/constants" eeTypes "github.com/SigNoz/signoz/ee/types" + "github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/query-service/auth" - baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants" - "github.com/SigNoz/signoz/pkg/query-service/dao" basemodel "github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/google/uuid" "github.com/gorilla/mux" "go.uber.org/zap" @@ -30,6 +30,12 @@ type CloudIntegrationConnectionParamsResponse struct { } func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseWriter, r *http.Request) { + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(w, err) + return + } + cloudProvider := mux.Vars(r)["cloudProvider"] if cloudProvider != "aws" { RespondError(w, basemodel.BadRequest(fmt.Errorf( @@ -38,15 +44,7 @@ func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseW return } - currentUser, err := auth.GetUserFromReqContext(r.Context()) - if err != nil { - RespondError(w, basemodel.UnauthorizedError(fmt.Errorf( - "couldn't deduce current user: %w", err, - )), nil) - return - } - - apiKey, apiErr := ah.getOrCreateCloudIntegrationPAT(r.Context(), currentUser.OrgID, cloudProvider) + apiKey, apiErr := ah.getOrCreateCloudIntegrationPAT(r.Context(), claims.OrgID, cloudProvider) if apiErr != nil { RespondError(w, basemodel.WrapApiError( apiErr, "couldn't provision PAT for cloud integration:", @@ -137,7 +135,7 @@ func (ah *APIHandler) getOrCreateCloudIntegrationPAT(ctx context.Context, orgId newPAT := eeTypes.NewGettablePAT( integrationPATName, - baseconstants.ViewerGroup, + authtypes.RoleViewer.String(), integrationUser.ID, 0, ) @@ -181,11 +179,7 @@ func (ah *APIHandler) getOrCreateCloudIntegrationUser( OrgID: orgId, } - viewerGroup, apiErr := dao.DB().GetGroupByName(ctx, baseconstants.ViewerGroup) - if apiErr != nil { - return nil, basemodel.WrapApiError(apiErr, "couldn't get viewer group for creating integration user") - } - newUser.GroupID = viewerGroup.ID + newUser.Role = authtypes.RoleViewer.String() passwordHash, err := auth.PasswordHash(uuid.NewString()) if err != nil { diff --git a/ee/query-service/app/api/dashboard.go b/ee/query-service/app/api/dashboard.go index 7a9a75bcc6..f43db74be0 100644 --- a/ee/query-service/app/api/dashboard.go +++ b/ee/query-service/app/api/dashboard.go @@ -7,7 +7,6 @@ import ( "github.com/SigNoz/signoz/pkg/errors" "github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/query-service/app/dashboards" - "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/gorilla/mux" ) @@ -36,18 +35,19 @@ func (ah *APIHandler) lockUnlockDashboard(w http.ResponseWriter, r *http.Request return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { render.Error(w, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) return } + dashboard, err := dashboards.GetDashboard(r.Context(), claims.OrgID, uuid) if err != nil { render.Error(w, errors.Wrapf(err, errors.TypeInternal, errors.CodeInternal, "failed to get dashboard")) return } - if !auth.IsAdminV2(claims) && (dashboard.CreatedBy != claims.Email) { + if err := claims.IsAdmin(); err != nil && (dashboard.CreatedBy != claims.Email) { render.Error(w, errors.Newf(errors.TypeForbidden, errors.CodeForbidden, "You are not authorized to lock/unlock this dashboard")) return } diff --git a/ee/query-service/app/api/pat.go b/ee/query-service/app/api/pat.go index b852a3be4e..185dba8ff5 100644 --- a/ee/query-service/app/api/pat.go +++ b/ee/query-service/app/api/pat.go @@ -1,7 +1,6 @@ package api import ( - "context" "encoding/json" "fmt" "net/http" @@ -13,35 +12,31 @@ import ( "github.com/SigNoz/signoz/pkg/errors" errorsV2 "github.com/SigNoz/signoz/pkg/errors" "github.com/SigNoz/signoz/pkg/http/render" - "github.com/SigNoz/signoz/pkg/query-service/auth" - baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants" basemodel "github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/valuer" "github.com/gorilla/mux" "go.uber.org/zap" ) func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { - ctx := context.Background() + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(w, err) + return + } req := model.CreatePATRequestBody{} if err := json.NewDecoder(r.Body).Decode(&req); err != nil { RespondError(w, model.BadRequest(err), nil) return } - user, err := auth.GetUserFromReqContext(r.Context()) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } + pat := eeTypes.NewGettablePAT( req.Name, req.Role, - user.ID, + claims.UserID, req.ExpiresInDays, ) err = validatePATRequest(pat) @@ -52,7 +47,7 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { zap.L().Info("Got Create PAT request", zap.Any("pat", pat)) var apierr basemodel.BaseApiError - if pat, apierr = ah.AppDao().CreatePAT(ctx, user.OrgID, pat); apierr != nil { + if pat, apierr = ah.AppDao().CreatePAT(r.Context(), claims.OrgID, pat); apierr != nil { RespondError(w, apierr, nil) return } @@ -61,20 +56,28 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { } func validatePATRequest(req eeTypes.GettablePAT) error { - if req.Role == "" || (req.Role != baseconstants.ViewerGroup && req.Role != baseconstants.EditorGroup && req.Role != baseconstants.AdminGroup) { - return fmt.Errorf("valid role is required") + _, err := authtypes.NewRole(req.Role) + if err != nil { + return err } + if req.ExpiresAt < 0 { return fmt.Errorf("valid expiresAt is required") } + if req.Name == "" { return fmt.Errorf("valid name is required") } + return nil } func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) { - ctx := context.Background() + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(w, err) + return + } req := eeTypes.GettablePAT{} if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -89,24 +92,15 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) { return } - user, err := auth.GetUserFromReqContext(r.Context()) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - //get the pat - existingPAT, paterr := ah.AppDao().GetPATByID(ctx, user.OrgID, id) + existingPAT, paterr := ah.AppDao().GetPATByID(r.Context(), claims.OrgID, id) if paterr != nil { render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, paterr.Error())) return } // get the user - createdByUser, usererr := ah.AppDao().GetUser(ctx, existingPAT.UserID) + createdByUser, usererr := ah.AppDao().GetUser(r.Context(), existingPAT.UserID) if usererr != nil { render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, usererr.Error())) return @@ -123,11 +117,11 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) { return } - req.UpdatedByUserID = user.ID + req.UpdatedByUserID = claims.UserID req.UpdatedAt = time.Now() zap.L().Info("Got Update PAT request", zap.Any("pat", req)) var apierr basemodel.BaseApiError - if apierr = ah.AppDao().UpdatePAT(ctx, user.OrgID, req, id); apierr != nil { + if apierr = ah.AppDao().UpdatePAT(r.Context(), claims.OrgID, req, id); apierr != nil { RespondError(w, apierr, nil) return } @@ -136,50 +130,44 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) { } func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { - ctx := context.Background() - user, err := auth.GetUserFromReqContext(r.Context()) + claims, err := authtypes.ClaimsFromContext(r.Context()) if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) + render.Error(w, err) return } - zap.L().Info("Get PATs for user", zap.String("user_id", user.ID)) - pats, apierr := ah.AppDao().ListPATs(ctx, user.OrgID) + + pats, apierr := ah.AppDao().ListPATs(r.Context(), claims.OrgID) if apierr != nil { RespondError(w, apierr, nil) return } + ah.Respond(w, pats) } func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) { - ctx := context.Background() + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(w, err) + return + } + idStr := mux.Vars(r)["id"] id, err := valuer.NewUUID(idStr) if err != nil { render.Error(w, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "id is not a valid uuid-v7")) return } - user, err := auth.GetUserFromReqContext(r.Context()) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } //get the pat - existingPAT, paterr := ah.AppDao().GetPATByID(ctx, user.OrgID, id) + existingPAT, paterr := ah.AppDao().GetPATByID(r.Context(), claims.OrgID, id) if paterr != nil { render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, paterr.Error())) return } // get the user - createdByUser, usererr := ah.AppDao().GetUser(ctx, existingPAT.UserID) + createdByUser, usererr := ah.AppDao().GetUser(r.Context(), existingPAT.UserID) if usererr != nil { render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, usererr.Error())) return @@ -191,7 +179,7 @@ func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) { } zap.L().Info("Revoke PAT with id", zap.String("id", id.StringValue())) - if apierr := ah.AppDao().RevokePAT(ctx, user.OrgID, id, user.ID); apierr != nil { + if apierr := ah.AppDao().RevokePAT(r.Context(), claims.OrgID, id, claims.UserID); apierr != nil { RespondError(w, apierr, nil) return } diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index c9af947bc8..455d1419b0 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -2,7 +2,6 @@ package app import ( "context" - "errors" "fmt" "net" "net/http" @@ -22,11 +21,9 @@ import ( "github.com/SigNoz/signoz/pkg/alertmanager" "github.com/SigNoz/signoz/pkg/http/middleware" "github.com/SigNoz/signoz/pkg/prometheus" - "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/signoz" "github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/telemetrystore" - "github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/web" "github.com/rs/cors" @@ -334,24 +331,8 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server, } func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*http.Server, error) { - r := baseapp.NewRouter() - - // add auth middleware - getUserFromRequest := func(ctx context.Context) (*types.GettableUser, error) { - user, err := auth.GetUserFromReqContext(ctx) - - if err != nil { - return nil, err - } - - if user.User.OrgID == "" { - return nil, basemodel.UnauthorizedError(errors.New("orgId is missing in the claims")) - } - - return user, nil - } - am := baseapp.NewAuthMiddleware(getUserFromRequest) + am := middleware.NewAuthZ(s.serverOptions.SigNoz.Instrumentation.Logger()) r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) r.Use(eemiddleware.NewPat(s.serverOptions.SigNoz.SQLStore, []string{"SIGNOZ-API-KEY"}).Wrap) diff --git a/ee/query-service/dao/sqlite/auth.go b/ee/query-service/dao/sqlite/auth.go index a651fabfea..7e4af543aa 100644 --- a/ee/query-service/dao/sqlite/auth.go +++ b/ee/query-service/dao/sqlite/auth.go @@ -9,7 +9,6 @@ import ( "github.com/SigNoz/signoz/ee/query-service/constants" "github.com/SigNoz/signoz/ee/query-service/model" baseauth "github.com/SigNoz/signoz/pkg/query-service/auth" - baseconst "github.com/SigNoz/signoz/pkg/query-service/constants" basemodel "github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/utils" "github.com/SigNoz/signoz/pkg/types" @@ -36,12 +35,6 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) ( return nil, model.InternalErrorStr("failed to generate password hash") } - group, apiErr := m.GetGroupByName(ctx, baseconst.ViewerGroup) - if apiErr != nil { - zap.L().Error("GetGroupByName failed", zap.Error(apiErr)) - return nil, apiErr - } - user := &types.User{ ID: uuid.New().String(), Name: "", @@ -51,11 +44,11 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) ( CreatedAt: time.Now(), }, ProfilePictureURL: "", // Currently unused - GroupID: group.ID, + Role: authtypes.RoleViewer.String(), OrgID: domain.OrgID, } - user, apiErr = m.CreateUser(ctx, user, false) + user, apiErr := m.CreateUser(ctx, user, false) if apiErr != nil { zap.L().Error("CreateUser failed", zap.Error(apiErr)) return nil, apiErr @@ -115,7 +108,7 @@ func (m *modelDao) CanUsePassword(ctx context.Context, email string) (bool, base return false, baseapierr } - if userPayload.Role != baseconst.AdminGroup { + if userPayload.Role != authtypes.RoleAdmin.String() { return false, model.BadRequest(fmt.Errorf("auth method not supported")) } diff --git a/ee/query-service/license/manager.go b/ee/query-service/license/manager.go index c67dcf7a25..c9c5d2e289 100644 --- a/ee/query-service/license/manager.go +++ b/ee/query-service/license/manager.go @@ -239,8 +239,8 @@ func (lm *Manager) ValidateV3(ctx context.Context) (reterr error) { func (lm *Manager) ActivateV3(ctx context.Context, licenseKey string) (licenseResponse *model.LicenseV3, errResponse *model.ApiError) { defer func() { if errResponse != nil { - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_LICENSE_ACT_FAILED, map[string]interface{}{"err": errResponse.Err.Error()}, claims.Email, true, false) } diff --git a/ee/query-service/main.go b/ee/query-service/main.go index d8e25e469c..ce2c1edf83 100644 --- a/ee/query-service/main.go +++ b/ee/query-service/main.go @@ -11,7 +11,6 @@ import ( "github.com/SigNoz/signoz/pkg/config" "github.com/SigNoz/signoz/pkg/config/envprovider" "github.com/SigNoz/signoz/pkg/config/fileprovider" - "github.com/SigNoz/signoz/pkg/query-service/auth" baseconst "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/signoz" "github.com/SigNoz/signoz/pkg/sqlstore/sqlstorehook" @@ -147,10 +146,6 @@ func main() { zap.L().Fatal("Could not start server", zap.Error(err)) } - if err := auth.InitAuthCache(context.Background()); err != nil { - zap.L().Fatal("Failed to initialize auth cache", zap.Error(err)) - } - signoz.Start(context.Background()) if err := signoz.Wait(context.Background()); err != nil { diff --git a/ee/sqlstore/postgressqlstore/dialect.go b/ee/sqlstore/postgressqlstore/dialect.go index 1a7ab7eb50..2fa2c0084e 100644 --- a/ee/sqlstore/postgressqlstore/dialect.go +++ b/ee/sqlstore/postgressqlstore/dialect.go @@ -28,10 +28,9 @@ var ( CloudIntegrationReference = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE` ) -type dialect struct { -} +type dialect struct{} -func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { +func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { columnType, err := dialect.GetColumnType(ctx, bun, table, column) if err != nil { return err @@ -78,7 +77,7 @@ func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, return nil } -func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { +func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { columnExists, err := dialect.ColumnExists(ctx, bun, table, column) if err != nil { return err @@ -420,3 +419,26 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel return nil } + +func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bunIDB bun.IDB, model interface{}, column string) error { + existingTable := bunIDB.Dialect().Tables().Get(reflect.TypeOf(model)) + columnExists, err := dialect.ColumnExists(ctx, bunIDB, existingTable.Name, column) + if err != nil { + return err + } + + if !columnExists { + return nil + } + + _, err = bunIDB. + NewDropColumn(). + Model(model). + Column(column). + Exec(ctx) + if err != nil { + return err + } + + return nil +} diff --git a/frontend/.gitignore b/frontend/.gitignore index 7d7c7a5f2d..a548963940 100644 --- a/frontend/.gitignore +++ b/frontend/.gitignore @@ -1,3 +1,4 @@ # Sentry Config File .env.sentry-build-plugin +.qodo diff --git a/frontend/src/providers/App/utils.ts b/frontend/src/providers/App/utils.ts index 798362c79c..9654a910c7 100644 --- a/frontend/src/providers/App/utils.ts +++ b/frontend/src/providers/App/utils.ts @@ -23,7 +23,6 @@ function getUserDefaults(): IUser { organization: '', orgId: '', role: 'VIEWER', - groupId: '', }; } diff --git a/frontend/src/tests/test-utils.tsx b/frontend/src/tests/test-utils.tsx index ec39c99ab7..c11f8a7cb6 100644 --- a/frontend/src/tests/test-utils.tsx +++ b/frontend/src/tests/test-utils.tsx @@ -151,7 +151,6 @@ export function getAppContextMock( organization: 'Nightswatch', orgId: 'does-not-matter-id', role: role as ROLES, - groupId: 'does-not-matter-groupId', }, org: [ { diff --git a/frontend/src/types/api/user/getUser.ts b/frontend/src/types/api/user/getUser.ts index be88234ff3..e066ec3103 100644 --- a/frontend/src/types/api/user/getUser.ts +++ b/frontend/src/types/api/user/getUser.ts @@ -15,5 +15,4 @@ export interface PayloadProps { profilePictureURL: string; organization: string; role: ROLES; - groupId: string; } diff --git a/pkg/alertmanager/api.go b/pkg/alertmanager/api.go index 95183936e3..ece7dcfa37 100644 --- a/pkg/alertmanager/api.go +++ b/pkg/alertmanager/api.go @@ -28,9 +28,9 @@ func (api *API) GetAlerts(rw http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) defer cancel() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + render.Error(rw, err) return } @@ -53,9 +53,9 @@ func (api *API) TestReceiver(rw http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) defer cancel() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + render.Error(rw, err) return } @@ -85,9 +85,9 @@ func (api *API) ListChannels(rw http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) defer cancel() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + render.Error(rw, err) return } @@ -122,9 +122,9 @@ func (api *API) GetChannelByID(rw http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) defer cancel() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + render.Error(rw, err) return } @@ -159,9 +159,9 @@ func (api *API) UpdateChannelByID(rw http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) defer cancel() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + render.Error(rw, err) return } @@ -209,9 +209,9 @@ func (api *API) DeleteChannelByID(rw http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) defer cancel() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + render.Error(rw, err) return } @@ -246,9 +246,9 @@ func (api *API) CreateChannel(rw http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) defer cancel() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + render.Error(rw, err) return } diff --git a/pkg/http/middleware/analytics.go b/pkg/http/middleware/analytics.go index dd0104459c..0930935bbc 100644 --- a/pkg/http/middleware/analytics.go +++ b/pkg/http/middleware/analytics.go @@ -46,8 +46,8 @@ func (a *Analytics) Wrap(next http.Handler) http.Handler { } if _, ok := telemetry.EnabledPaths()[path]; ok { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if ok { + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err == nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, claims.Email, true, false) } } @@ -134,8 +134,8 @@ func (a *Analytics) extractQueryRangeData(path string, r *http.Request) (map[str data["queryType"] = queryInfoResult.QueryType data["panelType"] = queryInfoResult.PanelType - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if ok { + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err == nil { // switch case to set data["screen"] based on the referrer switch { case dashboardMatched: diff --git a/pkg/http/middleware/auth.go b/pkg/http/middleware/auth.go index a9626c4fa3..719d66bdf1 100644 --- a/pkg/http/middleware/auth.go +++ b/pkg/http/middleware/auth.go @@ -28,9 +28,7 @@ func (a *Auth) Wrap(next http.Handler) http.Handler { values = append(values, r.Header.Get(header)) } - ctx, err := a.jwt.ContextFromRequest( - r.Context(), - values...) + ctx, err := a.jwt.ContextFromRequest(r.Context(), values...) if err != nil { next.ServeHTTP(w, r) return diff --git a/pkg/http/middleware/authz.go b/pkg/http/middleware/authz.go new file mode 100644 index 0000000000..60ef94aabc --- /dev/null +++ b/pkg/http/middleware/authz.go @@ -0,0 +1,105 @@ +package middleware + +import ( + "log/slog" + "net/http" + + "github.com/SigNoz/signoz/pkg/http/render" + "github.com/SigNoz/signoz/pkg/types/authtypes" + "github.com/gorilla/mux" +) + +const ( + authzDeniedMessage string = "::AUTHZ-DENIED::" +) + +type AuthZ struct { + logger *slog.Logger +} + +func NewAuthZ(logger *slog.Logger) *AuthZ { + if logger == nil { + panic("cannot build authz middleware, logger is empty") + } + + return &AuthZ{logger: logger} +} + +func (middleware *AuthZ) ViewAccess(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + claims, err := authtypes.ClaimsFromContext(req.Context()) + if err != nil { + render.Error(rw, err) + return + } + + if err := claims.IsViewer(); err != nil { + middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims) + render.Error(rw, err) + return + } + + next(rw, req) + }) +} + +func (middleware *AuthZ) EditAccess(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + claims, err := authtypes.ClaimsFromContext(req.Context()) + if err != nil { + render.Error(rw, err) + return + } + + if err := claims.IsEditor(); err != nil { + middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims) + render.Error(rw, err) + return + } + + next(rw, req) + }) +} + +func (middleware *AuthZ) AdminAccess(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + claims, err := authtypes.ClaimsFromContext(req.Context()) + if err != nil { + render.Error(rw, err) + return + } + + if err := claims.IsAdmin(); err != nil { + middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims) + render.Error(rw, err) + return + } + + next(rw, req) + }) +} + +func (middleware *AuthZ) SelfAccess(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + claims, err := authtypes.ClaimsFromContext(req.Context()) + if err != nil { + render.Error(rw, err) + return + } + + id := mux.Vars(req)["id"] + if err := claims.IsSelfAccess(id); err != nil { + middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims) + render.Error(rw, err) + return + } + + next(rw, req) + }) +} + +func (middleware *AuthZ) OpenAccess(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + next(rw, req) + }) +} diff --git a/pkg/http/middleware/logging.go b/pkg/http/middleware/logging.go index a9c666418d..61dbbab67d 100644 --- a/pkg/http/middleware/logging.go +++ b/pkg/http/middleware/logging.go @@ -136,8 +136,8 @@ func (middleware *Logging) getLogCommentKVs(r *http.Request) map[string]string { } var email string - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if ok { + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err == nil { email = claims.Email } diff --git a/pkg/http/render/render.go b/pkg/http/render/render.go index 63f6d07603..a3021e4b48 100644 --- a/pkg/http/render/render.go +++ b/pkg/http/render/render.go @@ -58,6 +58,8 @@ func Error(rw http.ResponseWriter, cause error) { httpCode = http.StatusUnauthorized case errors.TypeUnsupported: httpCode = http.StatusNotImplemented + case errors.TypeForbidden: + httpCode = http.StatusForbidden } rea := make([]responseerroradditional, len(a)) diff --git a/pkg/modules/organization/implorganization/api.go b/pkg/modules/organization/implorganization/api.go index f899428946..fdbacb3127 100644 --- a/pkg/modules/organization/implorganization/api.go +++ b/pkg/modules/organization/implorganization/api.go @@ -21,11 +21,12 @@ func NewAPI(module organization.Module) organization.API { } func (api *organizationAPI) Get(rw http.ResponseWriter, r *http.Request) { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(rw, err) return } + orgID, err := valuer.NewUUID(claims.OrgID) if err != nil { render.Error(rw, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid org id")) @@ -52,11 +53,12 @@ func (api *organizationAPI) GetAll(rw http.ResponseWriter, r *http.Request) { } func (api *organizationAPI) Update(rw http.ResponseWriter, r *http.Request) { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(rw, err) return } + orgID, err := valuer.NewUUID(claims.OrgID) if err != nil { render.Error(rw, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid org id")) diff --git a/pkg/modules/preference/api.go b/pkg/modules/preference/api.go index 9bf704875b..d02b682038 100644 --- a/pkg/modules/preference/api.go +++ b/pkg/modules/preference/api.go @@ -4,7 +4,6 @@ import ( "encoding/json" "net/http" - errorsV2 "github.com/SigNoz/signoz/pkg/errors" "github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/preferencetypes" @@ -30,9 +29,9 @@ func NewAPI(usecase Usecase) API { func (p *preferenceAPI) GetOrgPreference(rw http.ResponseWriter, r *http.Request) { preferenceId := mux.Vars(r)["preferenceId"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(rw, err) return } preference, err := p.usecase.GetOrgPreference( @@ -49,18 +48,18 @@ func (p *preferenceAPI) GetOrgPreference(rw http.ResponseWriter, r *http.Request func (p *preferenceAPI) UpdateOrgPreference(rw http.ResponseWriter, r *http.Request) { preferenceId := mux.Vars(r)["preferenceId"] req := preferencetypes.UpdatablePreference{} - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) - return - } - - err := json.NewDecoder(r.Body).Decode(&req) - + claims, err := authtypes.ClaimsFromContext(r.Context()) if err != nil { render.Error(rw, err) return } + + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + render.Error(rw, err) + return + } + err = p.usecase.UpdateOrgPreference(r.Context(), preferenceId, req.PreferenceValue, claims.OrgID) if err != nil { render.Error(rw, err) @@ -71,9 +70,9 @@ func (p *preferenceAPI) UpdateOrgPreference(rw http.ResponseWriter, r *http.Requ } func (p *preferenceAPI) GetAllOrgPreferences(rw http.ResponseWriter, r *http.Request) { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(rw, err) return } preferences, err := p.usecase.GetAllOrgPreferences( @@ -89,9 +88,9 @@ func (p *preferenceAPI) GetAllOrgPreferences(rw http.ResponseWriter, r *http.Req func (p *preferenceAPI) GetUserPreference(rw http.ResponseWriter, r *http.Request) { preferenceId := mux.Vars(r)["preferenceId"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(rw, err) return } @@ -108,14 +107,14 @@ func (p *preferenceAPI) GetUserPreference(rw http.ResponseWriter, r *http.Reques func (p *preferenceAPI) UpdateUserPreference(rw http.ResponseWriter, r *http.Request) { preferenceId := mux.Vars(r)["preferenceId"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(rw, err) return } req := preferencetypes.UpdatablePreference{} - err := json.NewDecoder(r.Body).Decode(&req) + err = json.NewDecoder(r.Body).Decode(&req) if err != nil { render.Error(rw, err) @@ -131,9 +130,9 @@ func (p *preferenceAPI) UpdateUserPreference(rw http.ResponseWriter, r *http.Req } func (p *preferenceAPI) GetAllUserPreferences(rw http.ResponseWriter, r *http.Request) { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, err := authtypes.ClaimsFromContext(r.Context()) + if err != nil { + render.Error(rw, err) return } preferences, err := p.usecase.GetAllUserPreferences( diff --git a/pkg/query-service/app/apdex.go b/pkg/query-service/app/apdex.go index a3e99e1744..d4d3dc4615 100644 --- a/pkg/query-service/app/apdex.go +++ b/pkg/query-service/app/apdex.go @@ -1,19 +1,19 @@ package app import ( - "errors" "net/http" "strings" + "github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/types/authtypes" ) func (aH *APIHandler) setApdexSettings(w http.ResponseWriter, r *http.Request) { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - RespondError(w, &model.ApiError{Err: errors.New("unauthorized"), Typ: model.ErrorUnauthorized}, nil) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } req, err := parseSetApdexScoreRequest(r) @@ -31,9 +31,9 @@ func (aH *APIHandler) setApdexSettings(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) getApdexSettings(w http.ResponseWriter, r *http.Request) { services := r.URL.Query().Get("services") - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - RespondError(w, &model.ApiError{Err: errors.New("unauthorized"), Typ: model.ErrorUnauthorized}, nil) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } apdexSet, err := dao.DB().GetApdexSettings(r.Context(), claims.OrgID, strings.Split(strings.TrimSpace(services), ",")) diff --git a/pkg/query-service/app/auth.go b/pkg/query-service/app/auth.go deleted file mode 100644 index 196cf65609..0000000000 --- a/pkg/query-service/app/auth.go +++ /dev/null @@ -1,123 +0,0 @@ -package app - -import ( - "context" - "errors" - "net/http" - - "github.com/SigNoz/signoz/pkg/query-service/auth" - "github.com/SigNoz/signoz/pkg/query-service/constants" - "github.com/SigNoz/signoz/pkg/query-service/model" - "github.com/SigNoz/signoz/pkg/types" - "github.com/gorilla/mux" -) - -type AuthMiddleware struct { - GetUserFromRequest func(r context.Context) (*types.GettableUser, error) -} - -func NewAuthMiddleware(f func(ctx context.Context) (*types.GettableUser, error)) *AuthMiddleware { - return &AuthMiddleware{ - GetUserFromRequest: f, - } -} - -func (am *AuthMiddleware) OpenAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - f(w, r) - } -} - -func (am *AuthMiddleware) ViewAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, err := am.GetUserFromRequest(r.Context()) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - - if !(auth.IsViewer(user) || auth.IsEditor(user) || auth.IsAdmin(user)) { - RespondError(w, &model.ApiError{ - Typ: model.ErrorForbidden, - Err: errors.New("API is accessible to viewers/editors/admins"), - }, nil) - return - } - ctx := context.WithValue(r.Context(), constants.ContextUserKey, user) - r = r.WithContext(ctx) - f(w, r) - } -} - -func (am *AuthMiddleware) EditAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, err := am.GetUserFromRequest(r.Context()) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - if !(auth.IsEditor(user) || auth.IsAdmin(user)) { - RespondError(w, &model.ApiError{ - Typ: model.ErrorForbidden, - Err: errors.New("API is accessible to editors/admins"), - }, nil) - return - } - ctx := context.WithValue(r.Context(), constants.ContextUserKey, user) - r = r.WithContext(ctx) - f(w, r) - } -} - -func (am *AuthMiddleware) SelfAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, err := am.GetUserFromRequest(r.Context()) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - id := mux.Vars(r)["id"] - if !(auth.IsSelfAccessRequest(user, id) || auth.IsAdmin(user)) { - RespondError(w, &model.ApiError{ - Typ: model.ErrorForbidden, - Err: errors.New("API is accessible for self access or to the admins"), - }, nil) - return - } - ctx := context.WithValue(r.Context(), constants.ContextUserKey, user) - r = r.WithContext(ctx) - f(w, r) - } -} - -func (am *AuthMiddleware) AdminAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, err := am.GetUserFromRequest(r.Context()) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - if !auth.IsAdmin(user) { - RespondError(w, &model.ApiError{ - Typ: model.ErrorForbidden, - Err: errors.New("API is accessible to admins only"), - }, nil) - return - } - ctx := context.WithValue(r.Context(), constants.ContextUserKey, user) - r = r.WithContext(ctx) - f(w, r) - } -} diff --git a/pkg/query-service/app/clickhouseReader/reader.go b/pkg/query-service/app/clickhouseReader/reader.go index 90cd2940ca..2522d7487c 100644 --- a/pkg/query-service/app/clickhouseReader/reader.go +++ b/pkg/query-service/app/clickhouseReader/reader.go @@ -1037,7 +1037,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con var serviceNameIntervalMap = map[string][]tracedetail.Interval{} var hasMissingSpans bool - claims, claimsPresent := authtypes.ClaimsFromContext(ctx) + claims, errv2 := authtypes.ClaimsFromContext(ctx) cachedTraceData, err := r.GetWaterfallSpansForTraceWithMetadataCache(ctx, traceID) if err == nil { startTime = cachedTraceData.StartTime @@ -1050,7 +1050,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con totalErrorSpans = cachedTraceData.TotalErrorSpans hasMissingSpans = cachedTraceData.HasMissingSpans - if claimsPresent { + if errv2 == nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false) } } @@ -1067,7 +1067,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con } totalSpans = uint64(len(searchScanResponses)) - if claimsPresent { + if errv2 == nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false) } @@ -3280,8 +3280,8 @@ func (r *ClickHouseReader) GetLogs(ctx context.Context, params *model.LogsFilter "lenFilters": lenFilters, } if lenFilters != 0 { - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) } } @@ -3322,8 +3322,8 @@ func (r *ClickHouseReader) TailLogs(ctx context.Context, client *model.LogsTailC "lenFilters": lenFilters, } if lenFilters != 0 { - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) } } @@ -3414,8 +3414,8 @@ func (r *ClickHouseReader) AggregateLogs(ctx context.Context, params *model.Logs "lenFilters": lenFilters, } if lenFilters != 0 { - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) } } @@ -6835,8 +6835,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea if traceSummary.NumSpans > uint64(params.MaxSpansInTrace) { zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace), zap.Uint64("Count", traceSummary.NumSpans)) - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { data := map[string]interface{}{ "traceSize": traceSummary.NumSpans, "maxSpansInTraceLimit": params.MaxSpansInTrace, @@ -6847,8 +6847,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details") } - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { data := map[string]interface{}{ "traceSize": traceSummary.NumSpans, "algo": "smart", @@ -6937,8 +6937,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea } end = time.Now() zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start))) - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { data := map[string]interface{}{ "traceSize": len(searchScanResponses), "spansRenderLimit": params.SpansRenderLimit, @@ -6976,8 +6976,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc if countSpans > uint64(params.MaxSpansInTrace) { zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace), zap.Uint64("Count", countSpans)) - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { data := map[string]interface{}{ "traceSize": countSpans, "maxSpansInTraceLimit": params.MaxSpansInTrace, @@ -6988,8 +6988,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details") } - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { data := map[string]interface{}{ "traceSize": countSpans, "algo": "smart", @@ -7049,8 +7049,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc } end = time.Now() zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start))) - claims, ok := authtypes.ClaimsFromContext(ctx) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 == nil { data := map[string]interface{}{ "traceSize": len(searchScanResponses), "spansRenderLimit": params.SpansRenderLimit, diff --git a/pkg/query-service/app/cloudintegrations/controller_test.go b/pkg/query-service/app/cloudintegrations/controller_test.go index d5e4106fa4..80e01f0bd0 100644 --- a/pkg/query-service/app/cloudintegrations/controller_test.go +++ b/pkg/query-service/app/cloudintegrations/controller_test.go @@ -6,12 +6,11 @@ import ( "github.com/SigNoz/signoz/pkg/modules/organization" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization" - "github.com/SigNoz/signoz/pkg/query-service/auth" - "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/utils" "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/google/uuid" "github.com/stretchr/testify/require" ) @@ -300,13 +299,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model return nil, model.InternalError(err) } - group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup) - if apiErr != nil { - return nil, model.InternalError(apiErr) - } - - auth.InitAuthCache(ctx) - userId := uuid.NewString() return dao.DB().CreateUser( ctx, @@ -316,7 +308,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model Email: userId[:8] + "test@test.com", Password: "test", OrgID: organization.ID.StringValue(), - GroupID: group.ID, + Role: authtypes.RoleAdmin.String(), }, true, ) diff --git a/pkg/query-service/app/explorer/db.go b/pkg/query-service/app/explorer/db.go index 461567f042..a5b3f4bc19 100644 --- a/pkg/query-service/app/explorer/db.go +++ b/pkg/query-service/app/explorer/db.go @@ -108,8 +108,8 @@ func CreateView(ctx context.Context, orgID string, view v3.SavedView) (valuer.UU createdAt := time.Now() updatedAt := time.Now() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { return valuer.UUID{}, fmt.Errorf("error in getting email from context") } @@ -177,8 +177,8 @@ func UpdateView(ctx context.Context, orgID string, uuid valuer.UUID, view v3.Sav return fmt.Errorf("error in marshalling explorer query data: %s", err.Error()) } - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { return fmt.Errorf("error in getting email from context") } diff --git a/pkg/query-service/app/http_handler.go b/pkg/query-service/app/http_handler.go index a26ce2d4d2..52d7ff0d87 100644 --- a/pkg/query-service/app/http_handler.go +++ b/pkg/query-service/app/http_handler.go @@ -21,6 +21,7 @@ import ( "github.com/SigNoz/signoz/pkg/alertmanager" "github.com/SigNoz/signoz/pkg/apis/fields" errorsV2 "github.com/SigNoz/signoz/pkg/errors" + "github.com/SigNoz/signoz/pkg/http/middleware" "github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/modules/organization" "github.com/SigNoz/signoz/pkg/modules/preference" @@ -54,7 +55,6 @@ import ( tracesV4 "github.com/SigNoz/signoz/pkg/query-service/app/traces/v4" "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/cache" - "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/contextlinks" v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3" "github.com/SigNoz/signoz/pkg/query-service/postprocess" @@ -405,7 +405,7 @@ func writeHttpResponse(w http.ResponseWriter, data interface{}) { } } -func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *middleware.AuthZ) { subRouter := router.PathPrefix("/api/v3").Subrouter() subRouter.HandleFunc("/autocomplete/aggregate_attributes", am.ViewAccess( withCacheControl(AutoCompleteCacheControlAge, aH.autocompleteAggregateAttributes))).Methods(http.MethodGet) @@ -431,14 +431,14 @@ func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *AuthMid subRouter.HandleFunc("/logs/livetail", am.ViewAccess(aH.liveTailLogs)).Methods(http.MethodGet) } -func (aH *APIHandler) RegisterFieldsRoutes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterFieldsRoutes(router *mux.Router, am *middleware.AuthZ) { subRouter := router.PathPrefix("/api/v1").Subrouter() subRouter.HandleFunc("/fields/keys", am.ViewAccess(aH.FieldsAPI.GetFieldsKeys)).Methods(http.MethodGet) subRouter.HandleFunc("/fields/values", am.ViewAccess(aH.FieldsAPI.GetFieldsValues)).Methods(http.MethodGet) } -func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *middleware.AuthZ) { hostsSubRouter := router.PathPrefix("/api/v1/hosts").Subrouter() hostsSubRouter.HandleFunc("/attribute_keys", am.ViewAccess(aH.getHostAttributeKeys)).Methods(http.MethodGet) hostsSubRouter.HandleFunc("/attribute_values", am.ViewAccess(aH.getHostAttributeValues)).Methods(http.MethodGet) @@ -498,12 +498,12 @@ func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *AuthMid infraOnboardingSubRouter.HandleFunc("/k8s/status", am.ViewAccess(aH.getK8sInfraOnboardingStatus)).Methods(http.MethodGet) } -func (aH *APIHandler) RegisterWebSocketPaths(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterWebSocketPaths(router *mux.Router, am *middleware.AuthZ) { subRouter := router.PathPrefix("/ws").Subrouter() subRouter.HandleFunc("/query_progress", am.ViewAccess(aH.GetQueryProgressUpdates)).Methods(http.MethodGet) } -func (aH *APIHandler) RegisterQueryRangeV4Routes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterQueryRangeV4Routes(router *mux.Router, am *middleware.AuthZ) { subRouter := router.PathPrefix("/api/v4").Subrouter() subRouter.HandleFunc("/query_range", am.ViewAccess(aH.QueryRangeV4)).Methods(http.MethodPost) subRouter.HandleFunc("/metric/metric_metadata", am.ViewAccess(aH.getMetricMetadata)).Methods(http.MethodGet) @@ -520,7 +520,7 @@ func (aH *APIHandler) RegisterPrivateRoutes(router *mux.Router) { } // RegisterRoutes registers routes for this handler on the given router -func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *middleware.AuthZ) { router.HandleFunc("/api/v1/query_range", am.ViewAccess(aH.queryRangeMetrics)).Methods(http.MethodGet) router.HandleFunc("/api/v1/query", am.ViewAccess(aH.queryMetrics)).Methods(http.MethodGet) router.HandleFunc("/api/v1/channels", am.ViewAccess(aH.AlertmanagerAPI.ListChannels)).Methods(http.MethodGet) @@ -649,7 +649,7 @@ func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *AuthMiddleware) { })).Methods(http.MethodGet) } -func (ah *APIHandler) MetricExplorerRoutes(router *mux.Router, am *AuthMiddleware) { +func (ah *APIHandler) MetricExplorerRoutes(router *mux.Router, am *middleware.AuthZ) { router.HandleFunc("/api/v1/metrics/filters/keys", am.ViewAccess(ah.FilterKeysSuggestion)). Methods(http.MethodGet) @@ -757,9 +757,9 @@ func (aH *APIHandler) PopulateTemporality(ctx context.Context, qp *v3.QueryRange } func (aH *APIHandler) listDowntimeSchedules(w http.ResponseWriter, r *http.Request) { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -1119,10 +1119,9 @@ func (aH *APIHandler) listRules(w http.ResponseWriter, r *http.Request) { } func (aH *APIHandler) getDashboards(w http.ResponseWriter, r *http.Request) { - - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } allDashboards, err := dashboards.GetDashboards(r.Context(), claims.OrgID) @@ -1189,11 +1188,10 @@ func (aH *APIHandler) getDashboards(w http.ResponseWriter, r *http.Request) { } func (aH *APIHandler) deleteDashboard(w http.ResponseWriter, r *http.Request) { - uuid := mux.Vars(r)["uuid"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } err := dashboards.DeleteDashboard(r.Context(), claims.OrgID, uuid) @@ -1268,7 +1266,6 @@ func (aH *APIHandler) queryDashboardVarsV2(w http.ResponseWriter, r *http.Reques } func (aH *APIHandler) updateDashboard(w http.ResponseWriter, r *http.Request) { - uuid := mux.Vars(r)["uuid"] var postData map[string]interface{} @@ -1283,9 +1280,9 @@ func (aH *APIHandler) updateDashboard(w http.ResponseWriter, r *http.Request) { return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } dashboard, apiError := dashboards.UpdateDashboard(r.Context(), claims.OrgID, claims.Email, uuid, postData) @@ -1302,9 +1299,9 @@ func (aH *APIHandler) getDashboard(w http.ResponseWriter, r *http.Request) { uuid := mux.Vars(r)["uuid"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } dashboard, apiError := dashboards.GetDashboard(r.Context(), claims.OrgID, uuid) @@ -1356,9 +1353,9 @@ func (aH *APIHandler) createDashboards(w http.ResponseWriter, r *http.Request) { RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, "Error reading request body") return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } dash, apiErr := dashboards.CreateDashboard(r.Context(), claims.OrgID, claims.Email, postData) @@ -1612,20 +1609,19 @@ func (aH *APIHandler) submitFeedback(w http.ResponseWriter, r *http.Request) { "email": email, "message": message, } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 == nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_INPRODUCT_FEEDBACK, data, claims.Email, true, false) } } func (aH *APIHandler) registerEvent(w http.ResponseWriter, r *http.Request) { - request, err := parseRegisterEventRequest(r) if aH.HandleError(w, err, http.StatusBadRequest) { return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 == nil { switch request.EventType { case model.TrackEvent: telemetry.GetInstance().SendEvent(request.EventName, request.Attributes, claims.Email, request.RateLimited, true) @@ -1737,8 +1733,8 @@ func (aH *APIHandler) getServices(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "number": len(*result), } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_NUMBER_OF_SERVICES, data, claims.Email, true, false) } @@ -1918,8 +1914,8 @@ func (aH *APIHandler) setTTL(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil) return } @@ -1946,8 +1942,8 @@ func (aH *APIHandler) getTTL(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil) return } @@ -2030,9 +2026,10 @@ func (aH *APIHandler) inviteUser(w http.ResponseWriter, r *http.Request) { resp, err := auth.Invite(r.Context(), req) if err != nil { - RespondError(w, &model.ApiError{Err: err, Typ: model.ErrorInternal}, nil) + render.Error(w, err) return } + aH.WriteJSON(w, r, resp) } @@ -2088,11 +2085,10 @@ func (aH *APIHandler) revokeInvite(w http.ResponseWriter, r *http.Request) { // listPendingInvites is used to list the pending invites. func (aH *APIHandler) listPendingInvites(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil) + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { + render.Error(w, errv2) return } invites, err := dao.DB().GetInvites(ctx, claims.OrgID) @@ -2309,18 +2305,13 @@ func (aH *APIHandler) deleteUser(w http.ResponseWriter, r *http.Request) { return } - adminGroup, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup) - if apiErr != nil { - RespondError(w, apiErr, "Failed to get admin group") - return - } - adminUsers, apiErr := dao.DB().GetUsersByGroup(ctx, adminGroup.ID) + adminUsers, apiErr := dao.DB().GetUsersByRole(ctx, authtypes.RoleAdmin) if apiErr != nil { RespondError(w, apiErr, "Failed to get admin group users") return } - if user.GroupID == adminGroup.ID && len(adminUsers) == 1 { + if user.Role == authtypes.RoleAdmin.String() && len(adminUsers) == 1 { RespondError(w, &model.ApiError{ Typ: model.ErrorInternal, Err: errors.New("cannot delete the last admin user")}, nil) @@ -2332,6 +2323,7 @@ func (aH *APIHandler) deleteUser(w http.ResponseWriter, r *http.Request) { RespondError(w, err, "Failed to delete user") return } + aH.WriteJSON(w, r, map[string]string{"data": "user deleted successfully"}) } @@ -2350,13 +2342,8 @@ func (aH *APIHandler) getRole(w http.ResponseWriter, r *http.Request) { }, nil) return } - group, err := dao.DB().GetGroup(context.Background(), user.GroupID) - if err != nil { - RespondError(w, err, "Failed to get group") - return - } - aH.WriteJSON(w, r, &model.UserRole{UserId: id, GroupName: group.Name}) + aH.WriteJSON(w, r, &model.UserRole{UserId: id, GroupName: user.Role}) } func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) { @@ -2368,14 +2355,9 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) { } ctx := context.Background() - newGroup, apiErr := dao.DB().GetGroupByName(ctx, req.GroupName) - if apiErr != nil { - RespondError(w, apiErr, "Failed to get user's group") - return - } - - if newGroup == nil { - RespondError(w, apiErr, "Specified group is not present") + role, err := authtypes.NewRole(req.GroupName) + if err != nil { + RespondError(w, &model.ApiError{Typ: model.ErrorBadData, Err: errors.New("invalid role")}, nil) return } @@ -2386,8 +2368,8 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) { } // Make sure that the request is not demoting the last admin user. - if user.GroupID == auth.AuthCacheObj.AdminGroupId { - adminUsers, apiErr := dao.DB().GetUsersByGroup(ctx, auth.AuthCacheObj.AdminGroupId) + if user.Role == authtypes.RoleAdmin.String() { + adminUsers, apiErr := dao.DB().GetUsersByRole(ctx, authtypes.RoleAdmin) if apiErr != nil { RespondError(w, apiErr, "Failed to fetch adminUsers") return @@ -2401,7 +2383,7 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) { } } - apiErr = dao.DB().UpdateUserGroup(context.Background(), user.ID, newGroup.ID) + apiErr = dao.DB().UpdateUserRole(context.Background(), user.ID, role) if apiErr != nil { RespondError(w, apiErr, "Failed to add user to group") return @@ -2525,7 +2507,7 @@ func (aH *APIHandler) WriteJSON(w http.ResponseWriter, r *http.Request, response } // RegisterMessagingQueuesRoutes adds messaging-queues routes -func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *middleware.AuthZ) { // Main messaging queues router messagingQueuesRouter := router.PathPrefix("/api/v1/messaging-queues").Subrouter() @@ -2567,7 +2549,7 @@ func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *Auth } // RegisterThirdPartyApiRoutes adds third-party-api integration routes -func (aH *APIHandler) RegisterThirdPartyApiRoutes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterThirdPartyApiRoutes(router *mux.Router, am *middleware.AuthZ) { // Main messaging queues router thirdPartyApiRouter := router.PathPrefix("/api/v1/third-party-apis").Subrouter() @@ -3494,7 +3476,7 @@ func (aH *APIHandler) getAllOrgPreferences( } // RegisterIntegrationRoutes Registers all Integrations -func (aH *APIHandler) RegisterIntegrationRoutes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterIntegrationRoutes(router *mux.Router, am *middleware.AuthZ) { subRouter := router.PathPrefix("/api/v1/integrations").Subrouter() subRouter.HandleFunc( @@ -3526,9 +3508,9 @@ func (aH *APIHandler) ListIntegrations( for k, values := range r.URL.Query() { params[k] = values[0] } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -3546,9 +3528,9 @@ func (aH *APIHandler) GetIntegration( w http.ResponseWriter, r *http.Request, ) { integrationId := mux.Vars(r)["integrationId"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } integration, apiErr := aH.IntegrationsController.GetIntegration( @@ -3566,9 +3548,9 @@ func (aH *APIHandler) GetIntegrationConnectionStatus( w http.ResponseWriter, r *http.Request, ) { integrationId := mux.Vars(r)["integrationId"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } isInstalled, apiErr := aH.IntegrationsController.IsIntegrationInstalled( @@ -3785,9 +3767,9 @@ func (aH *APIHandler) InstallIntegration( return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -3813,9 +3795,9 @@ func (aH *APIHandler) UninstallIntegration( return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -3829,7 +3811,7 @@ func (aH *APIHandler) UninstallIntegration( } // cloud provider integrations -func (aH *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *middleware.AuthZ) { subRouter := router.PathPrefix("/api/v1/cloud-integrations").Subrouter() subRouter.HandleFunc( @@ -3875,9 +3857,9 @@ func (aH *APIHandler) CloudIntegrationsListConnectedAccounts( ) { cloudProvider := mux.Vars(r)["cloudProvider"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -3903,9 +3885,9 @@ func (aH *APIHandler) CloudIntegrationsGenerateConnectionUrl( return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -3927,9 +3909,9 @@ func (aH *APIHandler) CloudIntegrationsGetAccountStatus( cloudProvider := mux.Vars(r)["cloudProvider"] accountId := mux.Vars(r)["accountId"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -3955,9 +3937,9 @@ func (aH *APIHandler) CloudIntegrationsAgentCheckIn( return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -3985,9 +3967,9 @@ func (aH *APIHandler) CloudIntegrationsUpdateAccountConfig( return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -4009,9 +3991,9 @@ func (aH *APIHandler) CloudIntegrationsDisconnectAccount( cloudProvider := mux.Vars(r)["cloudProvider"] accountId := mux.Vars(r)["accountId"] - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -4039,9 +4021,9 @@ func (aH *APIHandler) CloudIntegrationsListServices( cloudAccountId = &cloudAccountIdQP } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -4069,9 +4051,9 @@ func (aH *APIHandler) CloudIntegrationsGetServiceDetails( cloudAccountId = &cloudAccountIdQP } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -4315,9 +4297,9 @@ func (aH *APIHandler) CloudIntegrationsUpdateServiceConfig( return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -4334,7 +4316,7 @@ func (aH *APIHandler) CloudIntegrationsUpdateServiceConfig( } // logs -func (aH *APIHandler) RegisterLogsRoutes(router *mux.Router, am *AuthMiddleware) { +func (aH *APIHandler) RegisterLogsRoutes(router *mux.Router, am *middleware.AuthZ) { subRouter := router.PathPrefix("/api/v1/logs").Subrouter() subRouter.HandleFunc("", am.ViewAccess(aH.getLogs)).Methods(http.MethodGet) subRouter.HandleFunc("/tail", am.ViewAccess(aH.tailLogs)).Methods(http.MethodGet) @@ -4498,9 +4480,9 @@ func (aH *APIHandler) PreviewLogsPipelinesHandler(w http.ResponseWriter, r *http } func (aH *APIHandler) ListLogsPipelinesHandler(w http.ResponseWriter, r *http.Request) { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -4577,10 +4559,9 @@ func (aH *APIHandler) listLogsPipelinesByVersion(ctx context.Context, orgID stri } func (aH *APIHandler) CreateLogsPipeline(w http.ResponseWriter, r *http.Request) { - - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } @@ -4622,11 +4603,12 @@ func (aH *APIHandler) getSavedViews(w http.ResponseWriter, r *http.Request) { name := r.URL.Query().Get("name") category := r.URL.Query().Get("category") - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } + queries, err := explorer.GetViewsForFilters(r.Context(), claims.OrgID, sourcePage, name, category) if err != nil { RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, nil) @@ -4648,9 +4630,9 @@ func (aH *APIHandler) createSavedViews(w http.ResponseWriter, r *http.Request) { return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } uuid, err := explorer.CreateView(r.Context(), claims.OrgID, view) @@ -4670,9 +4652,9 @@ func (aH *APIHandler) getSavedView(w http.ResponseWriter, r *http.Request) { return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } view, err := explorer.GetView(r.Context(), claims.OrgID, viewUUID) @@ -4703,9 +4685,9 @@ func (aH *APIHandler) updateSavedView(w http.ResponseWriter, r *http.Request) { return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } err = explorer.UpdateView(r.Context(), claims.OrgID, viewUUID, view) @@ -4725,9 +4707,9 @@ func (aH *APIHandler) deleteSavedView(w http.ResponseWriter, r *http.Request) { render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, err.Error())) return } - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if !ok { - render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 != nil { + render.Error(w, errv2) return } err = explorer.DeleteView(r.Context(), claims.OrgID, viewUUID) @@ -5050,8 +5032,8 @@ func sendQueryResultEvents(r *http.Request, result []*v3.Result, queryRangeParam if len(result) > 0 && (len(result[0].Series) > 0 || len(result[0].List) > 0) { - claims, ok := authtypes.ClaimsFromContext(r.Context()) - if ok { + claims, errv2 := authtypes.ClaimsFromContext(r.Context()) + if errv2 == nil { queryInfoResult := telemetry.GetInstance().CheckQueryInfo(queryRangeParams) if queryInfoResult.LogsUsed || queryInfoResult.MetricsUsed || queryInfoResult.TracesUsed { diff --git a/pkg/query-service/app/integrations/test_utils.go b/pkg/query-service/app/integrations/test_utils.go index 7494961ff4..11a67fbd9b 100644 --- a/pkg/query-service/app/integrations/test_utils.go +++ b/pkg/query-service/app/integrations/test_utils.go @@ -6,14 +6,13 @@ import ( "testing" "github.com/SigNoz/signoz/pkg/modules/organization" - "github.com/SigNoz/signoz/pkg/query-service/auth" - "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/model" v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3" "github.com/SigNoz/signoz/pkg/query-service/utils" "github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/pipelinetypes" ruletypes "github.com/SigNoz/signoz/pkg/types/ruletypes" "github.com/google/uuid" @@ -42,13 +41,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model return nil, model.InternalError(err) } - group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup) - if apiErr != nil { - return nil, model.InternalError(apiErr) - } - - auth.InitAuthCache(ctx) - userId := uuid.NewString() return dao.DB().CreateUser( ctx, @@ -58,7 +50,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model Email: userId[:8] + "test@test.com", Password: "test", OrgID: organization.ID.StringValue(), - GroupID: group.ID, + Role: authtypes.RoleAdmin.String(), }, true, ) diff --git a/pkg/query-service/app/logparsingpipeline/controller.go b/pkg/query-service/app/logparsingpipeline/controller.go index 4b66fbb61c..cc077e66e0 100644 --- a/pkg/query-service/app/logparsingpipeline/controller.go +++ b/pkg/query-service/app/logparsingpipeline/controller.go @@ -54,8 +54,8 @@ func (ic *LogParsingPipelineController) ApplyPipelines( postable []pipelinetypes.PostablePipeline, ) (*PipelinesResponse, *model.ApiError) { // get user id from context - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { return nil, model.UnauthorizedError(fmt.Errorf("failed to get userId from context")) } diff --git a/pkg/query-service/app/logparsingpipeline/db.go b/pkg/query-service/app/logparsingpipeline/db.go index 6703c62035..719c2e2069 100644 --- a/pkg/query-service/app/logparsingpipeline/db.go +++ b/pkg/query-service/app/logparsingpipeline/db.go @@ -53,8 +53,8 @@ func (r *Repo) insertPipeline( )) } - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { return nil, model.UnauthorizedError(fmt.Errorf("failed to get email from context")) } diff --git a/pkg/query-service/app/metricsexplorer/summary.go b/pkg/query-service/app/metricsexplorer/summary.go index f58f6491d4..807a1d19d5 100644 --- a/pkg/query-service/app/metricsexplorer/summary.go +++ b/pkg/query-service/app/metricsexplorer/summary.go @@ -159,9 +159,9 @@ func (receiver *SummaryService) GetMetricsSummary(ctx context.Context, metricNam g.Go(func() error { var metricNames []string metricNames = append(metricNames, metricName) - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return &model.ApiError{Typ: model.ErrorInternal, Err: errors.New("failed to get claims")} + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { + return &model.ApiError{Typ: model.ErrorInternal, Err: errv2} } data, err := dashboards.GetDashboardsWithMetricNames(ctx, claims.OrgID, metricNames) if err != nil { @@ -332,9 +332,9 @@ func (receiver *SummaryService) GetRelatedMetrics(ctx context.Context, params *m alertsRelatedData := make(map[string][]metrics_explorer.Alert) g.Go(func() error { - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return &model.ApiError{Typ: model.ErrorInternal, Err: errors.New("failed to get claims")} + claims, errv2 := authtypes.ClaimsFromContext(ctx) + if errv2 != nil { + return &model.ApiError{Typ: model.ErrorInternal, Err: errv2} } names, apiError := dashboards.GetDashboardsWithMetricNames(ctx, claims.OrgID, metricNames) if apiError != nil { diff --git a/pkg/query-service/app/parser.go b/pkg/query-service/app/parser.go index cedaa1dde3..c6907400cc 100644 --- a/pkg/query-service/app/parser.go +++ b/pkg/query-service/app/parser.go @@ -16,6 +16,7 @@ import ( "github.com/SigNoz/signoz/pkg/query-service/app/integrations/messagingQueues/kafka" queues2 "github.com/SigNoz/signoz/pkg/query-service/app/integrations/messagingQueues/queues" "github.com/SigNoz/signoz/pkg/query-service/app/integrations/thirdPartyApi" + "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/govaluate" "github.com/gorilla/mux" @@ -27,7 +28,6 @@ import ( "github.com/SigNoz/signoz/pkg/query-service/app/queryBuilder" "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/common" - "github.com/SigNoz/signoz/pkg/query-service/constants" baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/model" v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3" @@ -492,14 +492,6 @@ func parseInviteRequest(r *http.Request) (*model.InviteRequest, error) { return &req, nil } -func isValidRole(role string) bool { - switch role { - case constants.AdminGroup, constants.EditorGroup, constants.ViewerGroup: - return true - } - return false -} - func parseInviteUsersRequest(r *http.Request) (*model.BulkInviteRequest, error) { var req model.BulkInviteRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -520,7 +512,9 @@ func parseInviteUsersRequest(r *http.Request) (*model.BulkInviteRequest, error) if req.Users[i].FrontendBaseUrl == "" { return nil, fmt.Errorf("frontendBaseUrl is required for each user") } - if !isValidRole(req.Users[i].Role) { + + _, err := authtypes.NewRole(req.Users[i].Role) + if err != nil { return nil, fmt.Errorf("invalid role for user: %s", req.Users[i].Email) } } diff --git a/pkg/query-service/app/server.go b/pkg/query-service/app/server.go index 8cb3c46dda..2d0034f29e 100644 --- a/pkg/query-service/app/server.go +++ b/pkg/query-service/app/server.go @@ -2,7 +2,6 @@ package app import ( "context" - "errors" "fmt" "net" "net/http" @@ -30,7 +29,6 @@ import ( "github.com/SigNoz/signoz/pkg/signoz" "github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/telemetrystore" - "github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/preferencetypes" "github.com/SigNoz/signoz/pkg/web" @@ -38,7 +36,6 @@ import ( "github.com/soheilhy/cmux" "github.com/SigNoz/signoz/pkg/query-service/app/explorer" - "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/cache" "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/dao" @@ -308,21 +305,7 @@ func (s *Server) createPublicServer(api *APIHandler, web web.Web) (*http.Server, r.Use(middleware.NewAnalytics(zap.L()).Wrap) r.Use(middleware.NewLogging(zap.L(), s.serverOptions.Config.APIServer.Logging.ExcludedRoutes).Wrap) - // add auth middleware - getUserFromRequest := func(ctx context.Context) (*types.GettableUser, error) { - user, err := auth.GetUserFromReqContext(ctx) - - if err != nil { - return nil, err - } - - if user.User.OrgID == "" { - return nil, model.UnauthorizedError(errors.New("orgId is missing in the claims")) - } - - return user, nil - } - am := NewAuthMiddleware(getUserFromRequest) + am := middleware.NewAuthZ(s.serverOptions.SigNoz.Instrumentation.Logger()) api.RegisterRoutes(r, am) api.RegisterLogsRoutes(r, am) diff --git a/pkg/query-service/auth/auth.go b/pkg/query-service/auth/auth.go index 10304732a6..4af12fb1af 100644 --- a/pkg/query-service/auth/auth.go +++ b/pkg/query-service/auth/auth.go @@ -26,18 +26,17 @@ import ( "golang.org/x/crypto/bcrypt" ) -type JwtContextKeyType string - -const AccessJwtKey JwtContextKeyType = "accessJwt" -const RefreshJwtKey JwtContextKeyType = "refreshJwt" - const ( opaqueTokenSize = 16 minimumPasswordLength = 8 ) var ( - ErrorInvalidCreds = fmt.Errorf("invalid credentials") + ErrorInvalidCreds = fmt.Errorf("invalid credentials") + ErrorEmptyRequest = errors.New("Empty request") + ErrorInvalidRole = errors.New("Invalid role") + ErrorInvalidInviteToken = errors.New("Invalid invite token") + ErrorAskAdmin = errors.New("An invitation is needed to create an account. Please ask your admin (the person who has first installed SIgNoz) to send an invite.") ) type InviteEmailData struct { @@ -49,7 +48,10 @@ type InviteEmailData struct { // The root user should be able to invite people to create account on SigNoz cluster. func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteResponse, error) { - zap.L().Debug("Got an invite request for email", zap.String("email", req.Email)) + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return nil, err + } token, err := utils.RandomHex(opaqueTokenSize) if err != nil { @@ -64,11 +66,6 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons if user != nil { return nil, errors.New("User already exists with the same email") } - - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return nil, errors.New("failed to extract OrgID from context") - } // Check if an invite already exists invite, apiErr := dao.DB().GetInviteFromEmail(ctx, req.Email) if apiErr != nil { @@ -79,8 +76,9 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons return nil, errors.New("An invite already exists for this email") } - if err := validateInviteRequest(req); err != nil { - return nil, errors.Wrap(err, "invalid invite request") + role, err := authtypes.NewRole(req.Role) + if err != nil { + return nil, err } au, apiErr := dao.DB().GetUser(ctx, claims.UserID) @@ -99,7 +97,7 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons Name: req.Name, Email: req.Email, Token: token, - Role: req.Role, + Role: role.String(), OrgID: au.OrgID, } @@ -120,6 +118,11 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons } func InviteUsers(ctx context.Context, req *model.BulkInviteRequest) (*model.BulkInviteResponse, error) { + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return nil, err + } + response := &model.BulkInviteResponse{ Status: "success", Summary: model.InviteSummary{TotalInvites: len(req.Users)}, @@ -127,11 +130,6 @@ func InviteUsers(ctx context.Context, req *model.BulkInviteRequest) (*model.Bulk FailedInvites: []model.FailedInvite{}, } - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return nil, errors.New("failed to extract admin user id") - } - au, apiErr := dao.DB().GetUser(ctx, claims.UserID) if apiErr != nil { return nil, errors.Wrap(apiErr.Err, "failed to query admin user from the DB") @@ -191,8 +189,9 @@ func inviteUser(ctx context.Context, req *model.InviteRequest, au *types.Gettabl return nil, errors.New("An invite already exists for this email") } - if err := validateInviteRequest(req); err != nil { - return nil, errors.Wrap(err, "invalid invite request") + role, err := authtypes.NewRole(req.Role) + if err != nil { + return nil, err } inv := &types.Invite{ @@ -206,7 +205,7 @@ func inviteUser(ctx context.Context, req *model.InviteRequest, au *types.Gettabl Name: req.Name, Email: req.Email, Token: token, - Role: req.Role, + Role: role.String(), OrgID: au.OrgID, } @@ -260,15 +259,9 @@ func inviteEmail(req *model.InviteRequest, au *types.GettableUser, token string) // RevokeInvite is used to revoke the invitation for the given email. func RevokeInvite(ctx context.Context, email string) error { - zap.L().Debug("RevokeInvite method invoked for email", zap.String("email", email)) - - if !isValidEmail(email) { - return ErrorInvalidInviteToken - } - - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return errors.New("failed to org id from context") + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return err } if err := dao.DB().DeleteInvitation(ctx, claims.OrgID, email); err != nil { @@ -414,19 +407,12 @@ func RegisterFirstUser(ctx context.Context, req *RegisterRequest, organizationMo return nil, model.BadRequest(model.ErrPasswordRequired{}) } - groupName := constants.AdminGroup organization := types.NewOrganization(req.OrgDisplayName) err := organizationModule.Create(ctx, organization) if err != nil { return nil, model.InternalError(err) } - group, apiErr := dao.DB().GetGroupByName(ctx, groupName) - if apiErr != nil { - zap.L().Error("GetGroupByName failed", zap.Error(apiErr.Err)) - return nil, apiErr - } - var hash string hash, err = PasswordHash(req.Password) if err != nil { @@ -443,7 +429,7 @@ func RegisterFirstUser(ctx context.Context, req *RegisterRequest, organizationMo CreatedAt: time.Now(), }, ProfilePictureURL: "", // Currently unused - GroupID: group.ID, + Role: authtypes.RoleAdmin.String(), OrgID: organization.ID.StringValue(), } @@ -488,13 +474,7 @@ func RegisterInvitedUser(ctx context.Context, req *RegisterRequest, nopassword b if invite.Role == "" { // if role is not provided, default to viewer - invite.Role = constants.ViewerGroup - } - - group, apiErr := dao.DB().GetGroupByName(ctx, invite.Role) - if apiErr != nil { - zap.L().Error("GetGroupByName failed", zap.Error(apiErr.Err)) - return nil, model.InternalError(model.ErrSignupFailed{}) + invite.Role = authtypes.RoleViewer.String() } var hash string @@ -523,12 +503,12 @@ func RegisterInvitedUser(ctx context.Context, req *RegisterRequest, nopassword b CreatedAt: time.Now(), }, ProfilePictureURL: "", // Currently unused - GroupID: group.ID, + Role: invite.Role, OrgID: invite.OrgID, } // TODO(Ahsan): Ideally create user and delete invitation should happen in a txn. - user, apiErr = dao.DB().CreateUser(ctx, user, false) + user, apiErr := dao.DB().CreateUser(ctx, user, false) if apiErr != nil { zap.L().Error("CreateUser failed", zap.Error(apiErr.Err)) return nil, apiErr @@ -574,8 +554,6 @@ func Register(ctx context.Context, req *RegisterRequest, alertmanager alertmanag // Login method returns access and refresh tokens on successful login, else it errors out. func Login(ctx context.Context, request *model.LoginRequest, jwt *authtypes.JWT) (*model.LoginResponse, error) { - zap.L().Debug("Login method called for user", zap.String("email", request.Email)) - user, err := authenticateLogin(ctx, request, jwt) if err != nil { zap.L().Error("Failed to authenticate login request", zap.Error(err)) @@ -599,18 +577,6 @@ func Login(ctx context.Context, request *model.LoginRequest, jwt *authtypes.JWT) }, nil } -func claimsToUserPayload(claims authtypes.Claims) (*types.GettableUser, error) { - user := &types.GettableUser{ - User: types.User{ - ID: claims.UserID, - GroupID: claims.GroupID, - Email: claims.Email, - OrgID: claims.OrgID, - }, - } - return user, nil -} - // authenticateLogin is responsible for querying the DB and validating the credentials. func authenticateLogin(ctx context.Context, req *model.LoginRequest, jwt *authtypes.JWT) (*types.GettableUser, error) { // If refresh token is valid, then simply authorize the login request. @@ -621,13 +587,13 @@ func authenticateLogin(ctx context.Context, req *model.LoginRequest, jwt *authty return nil, errors.Wrap(err, "failed to parse refresh token") } - if claims.OrgID == "" { - return nil, model.UnauthorizedError(errors.New("orgId is missing in the claims")) - } - - user, err := claimsToUserPayload(claims) - if err != nil { - return nil, errors.Wrap(err, "failed to convert claims to user payload") + user := &types.GettableUser{ + User: types.User{ + ID: claims.UserID, + Role: claims.Role.String(), + Email: claims.Email, + OrgID: claims.OrgID, + }, } return user, nil } @@ -658,18 +624,32 @@ func passwordMatch(hash, password string) bool { } func GenerateJWTForUser(user *types.User, jwt *authtypes.JWT) (model.UserJwtObject, error) { - j := model.UserJwtObject{} - var err error - j.AccessJwtExpiry = time.Now().Add(jwt.JwtExpiry).Unix() - j.AccessJwt, err = jwt.AccessToken(user.OrgID, user.ID, user.GroupID, user.Email) + role, err := authtypes.NewRole(user.Role) if err != nil { - return j, errors.Errorf("failed to encode jwt: %v", err) + return model.UserJwtObject{}, err } - j.RefreshJwtExpiry = time.Now().Add(jwt.JwtRefresh).Unix() - j.RefreshJwt, err = jwt.RefreshToken(user.OrgID, user.ID, user.GroupID, user.Email) + accessJwt, accessClaims, err := jwt.AccessToken(user.OrgID, user.ID, user.Email, role) if err != nil { - return j, errors.Errorf("failed to encode jwt: %v", err) + return model.UserJwtObject{}, err } - return j, nil + + refreshJwt, refreshClaims, err := jwt.RefreshToken(user.OrgID, user.ID, user.Email, role) + if err != nil { + return model.UserJwtObject{}, err + } + + return model.UserJwtObject{ + AccessJwt: accessJwt, + RefreshJwt: refreshJwt, + AccessJwtExpiry: accessClaims.ExpiresAt.Unix(), + RefreshJwtExpiry: refreshClaims.ExpiresAt.Unix(), + }, nil +} + +func ValidatePassword(password string) error { + if len(password) < minimumPasswordLength { + return errors.Errorf("Password should be atleast %d characters.", minimumPasswordLength) + } + return nil } diff --git a/pkg/query-service/auth/rbac.go b/pkg/query-service/auth/rbac.go deleted file mode 100644 index 9ffc02276c..0000000000 --- a/pkg/query-service/auth/rbac.go +++ /dev/null @@ -1,82 +0,0 @@ -package auth - -import ( - "context" - - errorsV2 "github.com/SigNoz/signoz/pkg/errors" - "github.com/SigNoz/signoz/pkg/query-service/constants" - "github.com/SigNoz/signoz/pkg/query-service/dao" - "github.com/SigNoz/signoz/pkg/types" - "github.com/SigNoz/signoz/pkg/types/authtypes" - "github.com/pkg/errors" -) - -type Group struct { - GroupID string - GroupName string -} - -type AuthCache struct { - AdminGroupId string - EditorGroupId string - ViewerGroupId string -} - -var AuthCacheObj AuthCache - -// InitAuthCache reads the DB and initialize the auth cache. -func InitAuthCache(ctx context.Context) error { - - setGroupId := func(groupName string, dest *string) error { - group, err := dao.DB().GetGroupByName(ctx, groupName) - if err != nil { - return errors.Wrapf(err.Err, "failed to get group %s", groupName) - } - *dest = group.ID - return nil - } - - if err := setGroupId(constants.AdminGroup, &AuthCacheObj.AdminGroupId); err != nil { - return err - } - if err := setGroupId(constants.EditorGroup, &AuthCacheObj.EditorGroupId); err != nil { - return err - } - if err := setGroupId(constants.ViewerGroup, &AuthCacheObj.ViewerGroupId); err != nil { - return err - } - - return nil -} - -func GetUserFromReqContext(ctx context.Context) (*types.GettableUser, error) { - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return nil, errorsV2.New(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, "no claims found in context") - } - - user := &types.GettableUser{ - User: types.User{ - ID: claims.UserID, - GroupID: claims.GroupID, - Email: claims.Email, - OrgID: claims.OrgID, - }, - } - return user, nil -} - -func IsSelfAccessRequest(user *types.GettableUser, id string) bool { return user.ID == id } - -func IsViewer(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.ViewerGroupId } -func IsEditor(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.EditorGroupId } -func IsAdmin(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.AdminGroupId } - -func IsAdminV2(claims authtypes.Claims) bool { return claims.GroupID == AuthCacheObj.AdminGroupId } - -func ValidatePassword(password string) error { - if len(password) < minimumPasswordLength { - return errors.Errorf("Password should be atleast %d characters.", minimumPasswordLength) - } - return nil -} diff --git a/pkg/query-service/auth/utils.go b/pkg/query-service/auth/utils.go deleted file mode 100644 index 5d0aaa3a3e..0000000000 --- a/pkg/query-service/auth/utils.go +++ /dev/null @@ -1,43 +0,0 @@ -package auth - -import ( - "github.com/SigNoz/signoz/pkg/query-service/constants" - "github.com/SigNoz/signoz/pkg/query-service/model" - "github.com/pkg/errors" -) - -var ( - ErrorEmptyRequest = errors.New("Empty request") - ErrorInvalidEmail = errors.New("Invalid email") - ErrorInvalidRole = errors.New("Invalid role") - - ErrorInvalidInviteToken = errors.New("Invalid invite token") - ErrorAskAdmin = errors.New("An invitation is needed to create an account. Please ask your admin (the person who has first installed SIgNoz) to send an invite.") -) - -func isValidRole(role string) bool { - switch role { - case constants.AdminGroup, constants.EditorGroup, constants.ViewerGroup: - return true - } - return false -} - -func validateInviteRequest(req *model.InviteRequest) error { - if req == nil { - return ErrorEmptyRequest - } - if !isValidEmail(req.Email) { - return ErrorInvalidEmail - } - - if !isValidRole(req.Role) { - return ErrorInvalidRole - } - return nil -} - -// TODO(Ahsan): Implement check on email semantic. -func isValidEmail(email string) bool { - return true -} diff --git a/pkg/query-service/constants/auth.go b/pkg/query-service/constants/auth.go deleted file mode 100644 index 7248f9dc1a..0000000000 --- a/pkg/query-service/constants/auth.go +++ /dev/null @@ -1,7 +0,0 @@ -package constants - -const ( - AdminGroup = "ADMIN" - EditorGroup = "EDITOR" - ViewerGroup = "VIEWER" -) diff --git a/pkg/query-service/dao/interface.go b/pkg/query-service/dao/interface.go index a1fdb9ed51..70c5c90823 100644 --- a/pkg/query-service/dao/interface.go +++ b/pkg/query-service/dao/interface.go @@ -5,6 +5,7 @@ import ( "github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/types/authtypes" ) type ModelDao interface { @@ -22,13 +23,9 @@ type Queries interface { GetUsers(ctx context.Context) ([]types.GettableUser, *model.ApiError) GetUsersWithOpts(ctx context.Context, limit int) ([]types.GettableUser, *model.ApiError) - GetGroup(ctx context.Context, id string) (*types.Group, *model.ApiError) - GetGroupByName(ctx context.Context, name string) (*types.Group, *model.ApiError) - GetGroups(ctx context.Context) ([]types.Group, *model.ApiError) - GetResetPasswordEntry(ctx context.Context, token string) (*types.ResetPasswordRequest, *model.ApiError) GetUsersByOrg(ctx context.Context, orgId string) ([]types.GettableUser, *model.ApiError) - GetUsersByGroup(ctx context.Context, groupId string) ([]types.GettableUser, *model.ApiError) + GetUsersByRole(ctx context.Context, role authtypes.Role) ([]types.GettableUser, *model.ApiError) GetApdexSettings(ctx context.Context, orgID string, services []string) ([]types.ApdexSettings, *model.ApiError) @@ -43,14 +40,11 @@ type Mutations interface { EditUser(ctx context.Context, update *types.User) (*types.User, *model.ApiError) DeleteUser(ctx context.Context, id string) *model.ApiError - CreateGroup(ctx context.Context, group *types.Group) (*types.Group, *model.ApiError) - DeleteGroup(ctx context.Context, id string) *model.ApiError - CreateResetPasswordEntry(ctx context.Context, req *types.ResetPasswordRequest) *model.ApiError DeleteResetPasswordEntry(ctx context.Context, token string) *model.ApiError UpdateUserPassword(ctx context.Context, hash, userId string) *model.ApiError - UpdateUserGroup(ctx context.Context, userId, groupId string) *model.ApiError + UpdateUserRole(ctx context.Context, userId string, role authtypes.Role) *model.ApiError SetApdexSettings(ctx context.Context, orgID string, set *types.ApdexSettings) *model.ApiError } diff --git a/pkg/query-service/dao/sqlite/connection.go b/pkg/query-service/dao/sqlite/connection.go index 060deed42a..1138b5d82d 100644 --- a/pkg/query-service/dao/sqlite/connection.go +++ b/pkg/query-service/dao/sqlite/connection.go @@ -9,7 +9,6 @@ import ( "github.com/SigNoz/signoz/pkg/types" "github.com/pkg/errors" "github.com/uptrace/bun" - "go.uber.org/zap" ) type ModelDaoSqlite struct { @@ -24,12 +23,8 @@ func InitDB(sqlStore sqlstore.SQLStore) (*ModelDaoSqlite, error) { if err := mds.initializeOrgPreferences(ctx); err != nil { return nil, err } - if err := mds.initializeRBAC(ctx); err != nil { - return nil, err - } telemetry.GetInstance().SetUserCountCallback(mds.GetUserCount) - telemetry.GetInstance().SetUserRoleCallback(mds.GetUserRole) telemetry.GetInstance().SetGetUsersCallback(mds.GetUsers) return mds, nil @@ -76,42 +71,3 @@ func (mds *ModelDaoSqlite) initializeOrgPreferences(ctx context.Context) error { return nil } - -// initializeRBAC creates the ADMIN, EDITOR and VIEWER groups if they are not present. -func (mds *ModelDaoSqlite) initializeRBAC(ctx context.Context) error { - f := func(groupName string) error { - _, err := mds.createGroupIfNotPresent(ctx, groupName) - return errors.Wrap(err, "Failed to create group") - } - - if err := f(constants.AdminGroup); err != nil { - return err - } - if err := f(constants.EditorGroup); err != nil { - return err - } - if err := f(constants.ViewerGroup); err != nil { - return err - } - - return nil -} - -func (mds *ModelDaoSqlite) createGroupIfNotPresent(ctx context.Context, - name string) (*types.Group, error) { - - group, err := mds.GetGroupByName(ctx, name) - if err != nil { - return nil, errors.Wrap(err.Err, "Failed to query for root group") - } - if group != nil { - return group, nil - } - - zap.L().Debug("group is not found, creating it", zap.String("group_name", name)) - group, cErr := mds.CreateGroup(ctx, &types.Group{Name: name}) - if cErr != nil { - return nil, cErr.Err - } - return group, nil -} diff --git a/pkg/query-service/dao/sqlite/rbac.go b/pkg/query-service/dao/sqlite/rbac.go index 1ac8f6de6a..1282b75ab5 100644 --- a/pkg/query-service/dao/sqlite/rbac.go +++ b/pkg/query-service/dao/sqlite/rbac.go @@ -7,7 +7,7 @@ import ( "github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/telemetry" "github.com/SigNoz/signoz/pkg/types" - "github.com/google/uuid" + "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/pkg/errors" ) @@ -163,11 +163,11 @@ func (mds *ModelDaoSqlite) UpdateUserPassword(ctx context.Context, passwordHash, return nil } -func (mds *ModelDaoSqlite) UpdateUserGroup(ctx context.Context, userId, groupId string) *model.ApiError { +func (mds *ModelDaoSqlite) UpdateUserRole(ctx context.Context, userId string, role authtypes.Role) *model.ApiError { _, err := mds.bundb.NewUpdate(). Model(&types.User{}). - Set("group_id = ?", groupId). + Set("role = ?", role). Where("id = ?", userId). Exec(ctx) @@ -207,10 +207,8 @@ func (mds *ModelDaoSqlite) GetUser(ctx context.Context, users := []types.GettableUser{} query := mds.bundb.NewSelect(). Table("users"). - Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). - ColumnExpr("g.name as role"). - ColumnExpr("o.display_name as organization"). - Join("JOIN groups g ON g.id = users.group_id"). + Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role"). + ColumnExpr("o.name as organization"). Join("JOIN organizations o ON o.id = users.org_id"). Where("users.id = ?", id) @@ -244,10 +242,8 @@ func (mds *ModelDaoSqlite) GetUserByEmail(ctx context.Context, users := []types.GettableUser{} query := mds.bundb.NewSelect(). Table("users"). - Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). - ColumnExpr("g.name as role"). - ColumnExpr("o.display_name as organization"). - Join("JOIN groups g ON g.id = users.group_id"). + Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role"). + ColumnExpr("o.name as organization"). Join("JOIN organizations o ON o.id = users.org_id"). Where("users.email = ?", email) @@ -279,10 +275,9 @@ func (mds *ModelDaoSqlite) GetUsersWithOpts(ctx context.Context, limit int) ([]t query := mds.bundb.NewSelect(). Table("users"). - Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). - ColumnExpr("g.name as role"). - ColumnExpr("o.display_name as organization"). - Join("JOIN groups g ON g.id = users.group_id"). + Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role"). + ColumnExpr("users.role as role"). + ColumnExpr("o.name as organization"). Join("JOIN organizations o ON o.id = users.org_id") if limit > 0 { @@ -303,10 +298,9 @@ func (mds *ModelDaoSqlite) GetUsersByOrg(ctx context.Context, query := mds.bundb.NewSelect(). Table("users"). - Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). - ColumnExpr("g.name as role"). - ColumnExpr("o.display_name as organization"). - Join("JOIN groups g ON g.id = users.group_id"). + Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role"). + ColumnExpr("users.role as role"). + ColumnExpr("o.name as organization"). Join("JOIN organizations o ON o.id = users.org_id"). Where("users.org_id = ?", orgId) @@ -317,19 +311,16 @@ func (mds *ModelDaoSqlite) GetUsersByOrg(ctx context.Context, return users, nil } -func (mds *ModelDaoSqlite) GetUsersByGroup(ctx context.Context, - groupId string) ([]types.GettableUser, *model.ApiError) { - +func (mds *ModelDaoSqlite) GetUsersByRole(ctx context.Context, role authtypes.Role) ([]types.GettableUser, *model.ApiError) { users := []types.GettableUser{} query := mds.bundb.NewSelect(). Table("users"). - Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). - ColumnExpr("g.name as role"). - ColumnExpr("o.display_name as organization"). - Join("JOIN groups g ON g.id = users.group_id"). + Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role"). + ColumnExpr("users.role as role"). + ColumnExpr("o.name as organization"). Join("JOIN organizations o ON o.id = users.org_id"). - Where("users.group_id = ?", groupId) + Where("users.role = ?", role) err := query.Scan(ctx, &users) if err != nil { @@ -338,98 +329,7 @@ func (mds *ModelDaoSqlite) GetUsersByGroup(ctx context.Context, return users, nil } -func (mds *ModelDaoSqlite) CreateGroup(ctx context.Context, - group *types.Group) (*types.Group, *model.ApiError) { - - group.ID = uuid.NewString() - - if _, err := mds.bundb.NewInsert(). - Model(group). - Exec(ctx); err != nil { - return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err} - } - - return group, nil -} - -func (mds *ModelDaoSqlite) DeleteGroup(ctx context.Context, id string) *model.ApiError { - - _, err := mds.bundb.NewDelete(). - Model(&types.Group{}). - Where("id = ?", id). - Exec(ctx) - - if err != nil { - return &model.ApiError{Typ: model.ErrorInternal, Err: err} - } - return nil -} - -func (mds *ModelDaoSqlite) GetGroup(ctx context.Context, - id string) (*types.Group, *model.ApiError) { - - groups := []types.Group{} - if err := mds.bundb.NewSelect(). - Model(&groups). - Where("id = ?", id). - Scan(ctx); err != nil { - return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err} - } - - if len(groups) > 1 { - return nil, &model.ApiError{ - Typ: model.ErrorInternal, - Err: errors.New("Found multiple groups with same ID."), - } - } - - if len(groups) == 0 { - return nil, nil - } - return &groups[0], nil -} - -func (mds *ModelDaoSqlite) GetGroupByName(ctx context.Context, - name string) (*types.Group, *model.ApiError) { - - groups := []types.Group{} - err := mds.bundb.NewSelect(). - Model(&groups). - Where("name = ?", name). - Scan(ctx) - if err != nil { - return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err} - } - - if len(groups) > 1 { - return nil, &model.ApiError{ - Typ: model.ErrorInternal, - Err: errors.New("Found multiple groups with same name"), - } - } - - if len(groups) == 0 { - return nil, nil - } - - return &groups[0], nil -} - -// TODO(nitya): should have org id -func (mds *ModelDaoSqlite) GetGroups(ctx context.Context) ([]types.Group, *model.ApiError) { - - groups := []types.Group{} - if err := mds.bundb.NewSelect(). - Model(&groups). - Scan(ctx); err != nil { - return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err} - } - - return groups, nil -} - -func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context, - req *types.ResetPasswordRequest) *model.ApiError { +func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context, req *types.ResetPasswordRequest) *model.ApiError { if _, err := mds.bundb.NewInsert(). Model(req). @@ -439,8 +339,7 @@ func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context, return nil } -func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context, - token string) *model.ApiError { +func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context, token string) *model.ApiError { _, err := mds.bundb.NewDelete(). Model(&types.ResetPasswordRequest{}). Where("token = ?", token). @@ -452,8 +351,7 @@ func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context, return nil } -func (mds *ModelDaoSqlite) GetResetPasswordEntry(ctx context.Context, - token string) (*types.ResetPasswordRequest, *model.ApiError) { +func (mds *ModelDaoSqlite) GetResetPasswordEntry(ctx context.Context, token string) (*types.ResetPasswordRequest, *model.ApiError) { entries := []types.ResetPasswordRequest{} @@ -491,14 +389,6 @@ func (mds *ModelDaoSqlite) PrecheckLogin(ctx context.Context, email, sourceUrl s return resp, nil } -func (mds *ModelDaoSqlite) GetUserRole(ctx context.Context, groupId string) (string, error) { - role, err := mds.GetGroup(ctx, groupId) - if err != nil || role == nil { - return "", err - } - return role.Name, nil -} - func (mds *ModelDaoSqlite) GetUserCount(ctx context.Context) (int, error) { users, err := mds.GetUsers(ctx) if err != nil { diff --git a/pkg/query-service/main.go b/pkg/query-service/main.go index e2a0f6114e..c8793b8b25 100644 --- a/pkg/query-service/main.go +++ b/pkg/query-service/main.go @@ -10,7 +10,6 @@ import ( "github.com/SigNoz/signoz/pkg/config/envprovider" "github.com/SigNoz/signoz/pkg/config/fileprovider" "github.com/SigNoz/signoz/pkg/query-service/app" - "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/signoz" "github.com/SigNoz/signoz/pkg/types/authtypes" @@ -139,10 +138,6 @@ func main() { logger.Fatal("Could not start servers", zap.Error(err)) } - if err := auth.InitAuthCache(context.Background()); err != nil { - logger.Fatal("Failed to initialize auth cache", zap.Error(err)) - } - signoz.Start(context.Background()) if err := signoz.Wait(context.Background()); err != nil { diff --git a/pkg/query-service/rules/manager.go b/pkg/query-service/rules/manager.go index 2b99d7da5c..87f23725e7 100644 --- a/pkg/query-service/rules/manager.go +++ b/pkg/query-service/rules/manager.go @@ -332,9 +332,9 @@ func (m *Manager) Stop(ctx context.Context) { // EditRuleDefinition writes the rule definition to the // datastore and also updates the rule executor func (m *Manager) EditRule(ctx context.Context, ruleStr string, idStr string) error { - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return errors.New("claims not found in context") + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return err } ruleUUID, err := valuer.NewUUID(idStr) @@ -469,9 +469,9 @@ func (m *Manager) DeleteRule(ctx context.Context, idStr string) error { return fmt.Errorf("delete rule received an rule id in invalid format, must be a valid uuid-v7") } - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return errors.New("claims not found in context") + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return err } _, err = m.ruleStore.GetStoredRule(ctx, id) @@ -523,9 +523,9 @@ func (m *Manager) deleteTask(taskName string) { // CreateRule stores rule def into db and also // starts an executor for the rule func (m *Manager) CreateRule(ctx context.Context, ruleStr string) (*ruletypes.GettableRule, error) { - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return nil, errors.New("claims not found in context") + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return nil, err } parsedRule, err := ruletypes.ParsePostableRule([]byte(ruleStr)) @@ -804,9 +804,9 @@ func (m *Manager) ListActiveRules() ([]Rule, error) { } func (m *Manager) ListRuleStates(ctx context.Context) (*ruletypes.GettableRules, error) { - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return nil, errors.New("claims not found in context") + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return nil, err } // fetch rules from DB storedRules, err := m.ruleStore.GetStoredRules(ctx, claims.OrgID) @@ -918,9 +918,9 @@ func (m *Manager) syncRuleStateWithTask(ctx context.Context, orgID string, taskN // - re-deploy or undeploy task as necessary // - update the patched rule in the DB func (m *Manager) PatchRule(ctx context.Context, ruleStr string, ruleIdStr string) (*ruletypes.GettableRule, error) { - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return nil, errors.New("claims not found in context") + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return nil, err } ruleID, err := valuer.NewUUID(ruleIdStr) @@ -1020,9 +1020,9 @@ func (m *Manager) TestNotification(ctx context.Context, ruleStr string) (int, *m } func (m *Manager) GetAlertDetailsForMetricNames(ctx context.Context, metricNames []string) (map[string][]ruletypes.GettableRule, *model.ApiError) { - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return nil, &model.ApiError{Typ: model.ErrorExec, Err: errors.New("claims not found in context")} + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return nil, &model.ApiError{Typ: model.ErrorExec, Err: err} } result := make(map[string][]ruletypes.GettableRule) diff --git a/pkg/query-service/telemetry/telemetry.go b/pkg/query-service/telemetry/telemetry.go index ede62dcfa6..e01f74ec6b 100644 --- a/pkg/query-service/telemetry/telemetry.go +++ b/pkg/query-service/telemetry/telemetry.go @@ -206,7 +206,6 @@ type Telemetry struct { alertsInfoCallback func(ctx context.Context) (*model.AlertsInfo, error) userCountCallback func(ctx context.Context) (int, error) - userRoleCallback func(ctx context.Context, groupId string) (string, error) getUsersCallback func(ctx context.Context) ([]types.GettableUser, *model.ApiError) dashboardsInfoCallback func(ctx context.Context) (*model.DashboardsInfo, error) savedViewsInfoCallback func(ctx context.Context) (*model.SavedViewsInfo, error) @@ -220,10 +219,6 @@ func (a *Telemetry) SetUserCountCallback(callback func(ctx context.Context) (int a.userCountCallback = callback } -func (a *Telemetry) SetUserRoleCallback(callback func(ctx context.Context, groupId string) (string, error)) { - a.userRoleCallback = callback -} - func (a *Telemetry) SetGetUsersCallback(callback func(ctx context.Context) ([]types.GettableUser, *model.ApiError)) { a.getUsersCallback = callback } @@ -555,21 +550,12 @@ func (a *Telemetry) IdentifyUser(user *types.User) { if !a.isTelemetryEnabled() || a.isTelemetryAnonymous() { return } - // extract user group from user.groupId - role, _ := a.userRoleCallback(context.Background(), user.GroupID) if a.saasOperator != nil { - if role != "" { - _ = a.saasOperator.Enqueue(analytics.Identify{ - UserId: a.userEmail, - Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email).Set("role", role), - }) - } else { - _ = a.saasOperator.Enqueue(analytics.Identify{ - UserId: a.userEmail, - Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email), - }) - } + _ = a.saasOperator.Enqueue(analytics.Identify{ + UserId: a.userEmail, + Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email).Set("role", user.Role), + }) _ = a.saasOperator.Enqueue(analytics.Group{ UserId: a.userEmail, diff --git a/pkg/query-service/tests/integration/filter_suggestions_test.go b/pkg/query-service/tests/integration/filter_suggestions_test.go index f37264d2cc..afb2eb80e8 100644 --- a/pkg/query-service/tests/integration/filter_suggestions_test.go +++ b/pkg/query-service/tests/integration/filter_suggestions_test.go @@ -10,9 +10,9 @@ import ( "testing" "github.com/SigNoz/signoz/pkg/http/middleware" + "github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization" "github.com/SigNoz/signoz/pkg/query-service/app" - "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/featureManager" @@ -310,7 +310,7 @@ func NewFilterSuggestionsTestBed(t *testing.T) *FilterSuggestionsTestBed { router := app.NewRouter() //add the jwt middleware router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) - am := app.NewAuthMiddleware(auth.GetUserFromReqContext) + am := middleware.NewAuthZ(instrumentationtest.New().Logger()) apiHandler.RegisterRoutes(router, am) apiHandler.RegisterQueryRangeV3Routes(router, am) diff --git a/pkg/query-service/tests/integration/signoz_cloud_integrations_test.go b/pkg/query-service/tests/integration/signoz_cloud_integrations_test.go index 6f0bc47dc5..f8b4a4d4a5 100644 --- a/pkg/query-service/tests/integration/signoz_cloud_integrations_test.go +++ b/pkg/query-service/tests/integration/signoz_cloud_integrations_test.go @@ -11,9 +11,9 @@ import ( "github.com/SigNoz/signoz/pkg/http/middleware" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization" + "github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest" "github.com/SigNoz/signoz/pkg/query-service/app" "github.com/SigNoz/signoz/pkg/query-service/app/cloudintegrations" - "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/featureManager" "github.com/SigNoz/signoz/pkg/query-service/utils" @@ -373,7 +373,7 @@ func NewCloudIntegrationsTestBed(t *testing.T, testDB sqlstore.SQLStore) *CloudI router := app.NewRouter() router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) - am := app.NewAuthMiddleware(auth.GetUserFromReqContext) + am := middleware.NewAuthZ(instrumentationtest.New().Logger()) apiHandler.RegisterRoutes(router, am) apiHandler.RegisterCloudIntegrationsRoutes(router, am) diff --git a/pkg/query-service/tests/integration/signoz_integrations_test.go b/pkg/query-service/tests/integration/signoz_integrations_test.go index cfd3154531..f20764d28d 100644 --- a/pkg/query-service/tests/integration/signoz_integrations_test.go +++ b/pkg/query-service/tests/integration/signoz_integrations_test.go @@ -9,11 +9,11 @@ import ( "time" "github.com/SigNoz/signoz/pkg/http/middleware" + "github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization" "github.com/SigNoz/signoz/pkg/query-service/app" "github.com/SigNoz/signoz/pkg/query-service/app/cloudintegrations" "github.com/SigNoz/signoz/pkg/query-service/app/integrations" - "github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/featureManager" "github.com/SigNoz/signoz/pkg/query-service/model" @@ -580,7 +580,7 @@ func NewIntegrationsTestBed(t *testing.T, testDB sqlstore.SQLStore) *Integration router := app.NewRouter() router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) - am := app.NewAuthMiddleware(auth.GetUserFromReqContext) + am := middleware.NewAuthZ(instrumentationtest.New().Logger()) apiHandler.RegisterRoutes(router, am) apiHandler.RegisterIntegrationRoutes(router, am) diff --git a/pkg/query-service/tests/integration/test_utils.go b/pkg/query-service/tests/integration/test_utils.go index 53bbf4a670..2d8f11009b 100644 --- a/pkg/query-service/tests/integration/test_utils.go +++ b/pkg/query-service/tests/integration/test_utils.go @@ -20,7 +20,6 @@ import ( "github.com/SigNoz/signoz/pkg/query-service/app" "github.com/SigNoz/signoz/pkg/query-service/app/clickhouseReader" "github.com/SigNoz/signoz/pkg/query-service/auth" - "github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/sqlstore" @@ -158,13 +157,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model return nil, model.InternalError(err) } - group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup) - if apiErr != nil { - return nil, model.InternalError(apiErr) - } - - auth.InitAuthCache(ctx) - userId := uuid.NewString() return dao.DB().CreateUser( @@ -175,7 +167,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model Email: userId[:8] + "test@test.com", Password: "test", OrgID: organization.ID.StringValue(), - GroupID: group.ID, + Role: authtypes.RoleAdmin.String(), }, true, ) diff --git a/pkg/query-service/utils/testutils.go b/pkg/query-service/utils/testutils.go index d14fac4d99..a1e3110f5b 100644 --- a/pkg/query-service/utils/testutils.go +++ b/pkg/query-service/utils/testutils.go @@ -58,6 +58,7 @@ func NewTestSqliteDB(t *testing.T) (sqlStore sqlstore.SQLStore, testDBFilePath s sqlmigration.NewAddVirtualFieldsFactory(), sqlmigration.NewUpdateIntegrationsFactory(sqlStore), sqlmigration.NewUpdateOrganizationsFactory(sqlStore), + sqlmigration.NewDropGroupsFactory(sqlStore), ), ) if err != nil { diff --git a/pkg/ruler/rulestore/sqlrulestore/maintenance.go b/pkg/ruler/rulestore/sqlrulestore/maintenance.go index eae6727923..c6bb000f0a 100644 --- a/pkg/ruler/rulestore/sqlrulestore/maintenance.go +++ b/pkg/ruler/rulestore/sqlrulestore/maintenance.go @@ -2,7 +2,6 @@ package sqlrulestore import ( "context" - "errors" "time" "github.com/SigNoz/signoz/pkg/sqlstore" @@ -60,10 +59,9 @@ func (r *maintenance) GetPlannedMaintenanceByID(ctx context.Context, id valuer.U } func (r *maintenance) CreatePlannedMaintenance(ctx context.Context, maintenance ruletypes.GettablePlannedMaintenance) (valuer.UUID, error) { - - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return valuer.UUID{}, errors.New("no claims found in context") + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return valuer.UUID{}, err } storablePlannedMaintenance := ruletypes.StorablePlannedMaintenance{ @@ -100,7 +98,7 @@ func (r *maintenance) CreatePlannedMaintenance(ctx context.Context, maintenance }) } - err := r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error { + err = r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error { _, err := r.sqlstore. BunDBCtx(ctx). NewInsert(). @@ -147,9 +145,9 @@ func (r *maintenance) DeletePlannedMaintenance(ctx context.Context, id valuer.UU } func (r *maintenance) EditPlannedMaintenance(ctx context.Context, maintenance ruletypes.GettablePlannedMaintenance, id valuer.UUID) error { - claims, ok := authtypes.ClaimsFromContext(ctx) - if !ok { - return errors.New("no claims found in context") + claims, err := authtypes.ClaimsFromContext(ctx) + if err != nil { + return err } storablePlannedMaintenance := ruletypes.StorablePlannedMaintenance{ @@ -186,7 +184,7 @@ func (r *maintenance) EditPlannedMaintenance(ctx context.Context, maintenance ru }) } - err := r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error { + err = r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error { _, err := r.sqlstore. BunDBCtx(ctx). NewUpdate(). diff --git a/pkg/signoz/provider.go b/pkg/signoz/provider.go index d3ae5e0886..1e9f1878a2 100644 --- a/pkg/signoz/provider.go +++ b/pkg/signoz/provider.go @@ -73,6 +73,7 @@ func NewSQLMigrationProviderFactories(sqlstore sqlstore.SQLStore) factory.NamedM sqlmigration.NewAddVirtualFieldsFactory(), sqlmigration.NewUpdateIntegrationsFactory(sqlstore), sqlmigration.NewUpdateOrganizationsFactory(sqlstore), + sqlmigration.NewDropGroupsFactory(sqlstore), ) } diff --git a/pkg/signoz/signoz.go b/pkg/signoz/signoz.go index fe1a09a652..0a1ab8b757 100644 --- a/pkg/signoz/signoz.go +++ b/pkg/signoz/signoz.go @@ -19,12 +19,13 @@ import ( type SigNoz struct { *factory.Registry - Cache cache.Cache - Web web.Web - SQLStore sqlstore.SQLStore - TelemetryStore telemetrystore.TelemetryStore - Prometheus prometheus.Prometheus - Alertmanager alertmanager.Alertmanager + Instrumentation instrumentation.Instrumentation + Cache cache.Cache + Web web.Web + SQLStore sqlstore.SQLStore + TelemetryStore telemetrystore.TelemetryStore + Prometheus prometheus.Prometheus + Alertmanager alertmanager.Alertmanager } func New( @@ -144,12 +145,13 @@ func New( } return &SigNoz{ - Registry: registry, - Cache: cache, - Web: web, - SQLStore: sqlstore, - TelemetryStore: telemetrystore, - Prometheus: prometheus, - Alertmanager: alertmanager, + Registry: registry, + Instrumentation: instrumentation, + Cache: cache, + Web: web, + SQLStore: sqlstore, + TelemetryStore: telemetrystore, + Prometheus: prometheus, + Alertmanager: alertmanager, }, nil } diff --git a/pkg/sqlmigration/013_update_organization.go b/pkg/sqlmigration/013_update_organization.go index 6cfd6fcb8e..cec593b15f 100644 --- a/pkg/sqlmigration/013_update_organization.go +++ b/pkg/sqlmigration/013_update_organization.go @@ -87,14 +87,14 @@ func (migration *updateOrganization) Up(ctx context.Context, db *bun.DB) error { // since organizations, users has created_at as integer instead of timestamp for _, table := range []string{"organizations", "users", "invites"} { - if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "created_at"); err != nil { + if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "created_at"); err != nil { return err } } // migrate is_anonymous and has_opted_updates to boolean from int for _, column := range []string{"is_anonymous", "has_opted_updates"} { - if err := migration.store.Dialect().MigrateIntToBoolean(ctx, tx, "organizations", column); err != nil { + if err := migration.store.Dialect().IntToBoolean(ctx, tx, "organizations", column); err != nil { return err } } diff --git a/pkg/sqlmigration/016_pat_org_domains.go b/pkg/sqlmigration/016_pat_org_domains.go index b169f456f7..eff3125fcc 100644 --- a/pkg/sqlmigration/016_pat_org_domains.go +++ b/pkg/sqlmigration/016_pat_org_domains.go @@ -66,16 +66,16 @@ func (migration *updatePatAndOrgDomains) Up(ctx context.Context, db *bun.DB) err } } - if err := updateOrgId(ctx, tx, "org_domains"); err != nil { + if err := updateOrgId(ctx, tx); err != nil { return err } // change created_at and updated_at from integer to timestamp for _, table := range []string{"personal_access_tokens", "org_domains"} { - if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "created_at"); err != nil { + if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "created_at"); err != nil { return err } - if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "updated_at"); err != nil { + if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "updated_at"); err != nil { return err } } @@ -96,7 +96,7 @@ func (migration *updatePatAndOrgDomains) Down(ctx context.Context, db *bun.DB) e return nil } -func updateOrgId(ctx context.Context, tx bun.Tx, table string) error { +func updateOrgId(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewCreateTable(). Model(&struct { bun.BaseModel `bun:"table:org_domains_new"` diff --git a/pkg/sqlmigration/026_update_integrations.go b/pkg/sqlmigration/026_update_integrations.go index 3fecb03f68..ceb7be5ede 100644 --- a/pkg/sqlmigration/026_update_integrations.go +++ b/pkg/sqlmigration/026_update_integrations.go @@ -359,7 +359,20 @@ func (migration *updateIntegrations) CopyOldCloudIntegrationServicesToNewCloudIn } func (migration *updateIntegrations) copyOldAwsIntegrationUser(tx bun.IDB, orgID string) error { - user := &types.User{} + type oldUser struct { + bun.BaseModel `bun:"table:users"` + + types.TimeAuditable + ID string `bun:"id,pk,type:text" json:"id"` + Name string `bun:"name,type:text,notnull" json:"name"` + Email string `bun:"email,type:text,notnull,unique" json:"email"` + Password string `bun:"password,type:text,notnull" json:"-"` + ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"` + GroupID string `bun:"group_id,type:text,notnull" json:"groupId"` + OrgID string `bun:"org_id,type:text,notnull" json:"orgId"` + } + + user := &oldUser{} err := tx.NewSelect().Model(user).Where("email = ?", "aws-integration@signoz.io").Scan(context.Background()) if err != nil { if err == sql.ErrNoRows { @@ -374,7 +387,7 @@ func (migration *updateIntegrations) copyOldAwsIntegrationUser(tx bun.IDB, orgID } // new user - newUser := &types.User{ + newUser := &oldUser{ ID: uuid.New().String(), TimeAuditable: types.TimeAuditable{ CreatedAt: time.Now(), diff --git a/pkg/sqlmigration/029_drop_groups.go b/pkg/sqlmigration/029_drop_groups.go new file mode 100644 index 0000000000..211b9b68ec --- /dev/null +++ b/pkg/sqlmigration/029_drop_groups.go @@ -0,0 +1,141 @@ +package sqlmigration + +import ( + "context" + + "github.com/SigNoz/signoz/pkg/factory" + "github.com/SigNoz/signoz/pkg/sqlstore" + "github.com/SigNoz/signoz/pkg/types" + "github.com/uptrace/bun" + "github.com/uptrace/bun/migrate" +) + +type dropGroups struct { + sqlstore sqlstore.SQLStore +} + +func NewDropGroupsFactory(sqlstore sqlstore.SQLStore) factory.ProviderFactory[SQLMigration, Config] { + return factory.NewProviderFactory(factory.MustNewName("drop_groups"), func(ctx context.Context, providerSettings factory.ProviderSettings, config Config) (SQLMigration, error) { + return newDropGroups(ctx, providerSettings, config, sqlstore) + }) +} + +func newDropGroups(_ context.Context, _ factory.ProviderSettings, _ Config, sqlstore sqlstore.SQLStore) (SQLMigration, error) { + return &dropGroups{sqlstore: sqlstore}, nil +} + +func (migration *dropGroups) Register(migrations *migrate.Migrations) error { + if err := migrations.Register(migration.Up, migration.Down); err != nil { + return err + } + + return nil +} + +func (migration *dropGroups) Up(ctx context.Context, db *bun.DB) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + defer tx.Rollback() + + type Group struct { + bun.BaseModel `bun:"table:groups"` + + types.TimeAuditable + OrgID string `bun:"org_id,type:text"` + ID string `bun:"id,pk,type:text" json:"id"` + Name string `bun:"name,type:text,notnull,unique" json:"name"` + } + + type existingUser struct { + bun.BaseModel `bun:"table:users"` + + types.TimeAuditable + ID string `bun:"id,pk,type:text" json:"id"` + Name string `bun:"name,type:text,notnull" json:"name"` + Email string `bun:"email,type:text,notnull,unique" json:"email"` + Password string `bun:"password,type:text,notnull" json:"-"` + ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"` + GroupID string `bun:"group_id,type:text,notnull" json:"groupId"` + OrgID string `bun:"org_id,type:text,notnull" json:"orgId"` + } + + var existingUsers []*existingUser + if err := tx. + NewSelect(). + Model(&existingUsers). + Scan(ctx); err != nil { + return err + } + + var groups []*Group + if err := tx. + NewSelect(). + Model(&groups). + Scan(ctx); err != nil { + return err + } + + groupIDToRoleMap := make(map[string]string) + for _, group := range groups { + groupIDToRoleMap[group.ID] = group.Name + } + + roleToUserIDMap := make(map[string][]string) + for _, user := range existingUsers { + roleToUserIDMap[groupIDToRoleMap[user.GroupID]] = append(roleToUserIDMap[groupIDToRoleMap[user.GroupID]], user.ID) + } + + if err := migration.sqlstore.Dialect().DropColumnWithForeignKeyConstraint(ctx, tx, new(struct { + bun.BaseModel `bun:"table:users"` + + types.TimeAuditable + ID string `bun:"id,pk,type:text"` + Name string `bun:"name,type:text,notnull"` + Email string `bun:"email,type:text,notnull,unique"` + Password string `bun:"password,type:text,notnull"` + ProfilePictureURL string `bun:"profile_picture_url,type:text"` + OrgID string `bun:"org_id,type:text,notnull"` + }), "group_id"); err != nil { + return err + } + + if err := migration.sqlstore.Dialect().AddColumn(ctx, tx, "users", "role", "TEXT"); err != nil { + return err + } + + for role, userIDs := range roleToUserIDMap { + if _, err := tx. + NewUpdate(). + Table("users"). + Set("role = ?", role). + Where("id IN (?)", bun.In(userIDs)). + Exec(ctx); err != nil { + return err + } + } + + if err := migration.sqlstore.Dialect().AddNotNullDefaultToColumn(ctx, tx, "users", "role", "TEXT", "'VIEWER'"); err != nil { + return err + } + + if _, err := tx. + NewDropTable(). + Table("groups"). + IfExists(). + Exec(ctx); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (migration *dropGroups) Down(ctx context.Context, db *bun.DB) error { + return nil +} diff --git a/pkg/sqlmigration/sqlmigration.go b/pkg/sqlmigration/sqlmigration.go index 5a527b8974..8741e95798 100644 --- a/pkg/sqlmigration/sqlmigration.go +++ b/pkg/sqlmigration/sqlmigration.go @@ -6,7 +6,6 @@ import ( "github.com/SigNoz/signoz/pkg/factory" "github.com/uptrace/bun" - "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/migrate" ) @@ -66,29 +65,3 @@ func MustNew( } return migrations } - -func GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) { - var columnType string - var err error - - if bun.Dialect().Name() == dialect.SQLite { - err = bun.NewSelect(). - ColumnExpr("type"). - TableExpr("pragma_table_info(?)", table). - Where("name = ?", column). - Scan(ctx, &columnType) - } else { - err = bun.NewSelect(). - ColumnExpr("data_type"). - TableExpr("information_schema.columns"). - Where("table_name = ?", table). - Where("column_name = ?", column). - Scan(ctx, &columnType) - } - - if err != nil { - return "", err - } - - return columnType, nil -} diff --git a/pkg/sqlstore/sqlitesqlstore/dialect.go b/pkg/sqlstore/sqlitesqlstore/dialect.go index 268358a6a4..a27f4b7763 100644 --- a/pkg/sqlstore/sqlitesqlstore/dialect.go +++ b/pkg/sqlstore/sqlitesqlstore/dialect.go @@ -5,33 +5,53 @@ import ( "fmt" "reflect" "slices" + "strings" "github.com/SigNoz/signoz/pkg/errors" "github.com/uptrace/bun" ) -var ( - Identity = "id" - Integer = "INTEGER" - Text = "TEXT" +const ( + Identity string = "id" + Integer string = "INTEGER" + Text string = "TEXT" ) -var ( - Org = "org" - User = "user" - CloudIntegration = "cloud_integration" +const ( + Org string = "org" + User string = "user" + CloudIntegration string = "cloud_integration" ) -var ( - OrgReference = `("org_id") REFERENCES "organizations" ("id")` - UserReference = `("user_id") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE` - CloudIntegrationReference = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE` +const ( + OrgReference string = `("org_id") REFERENCES "organizations" ("id")` + UserReference string = `("user_id") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE` + CloudIntegrationReference string = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE` ) -type dialect struct { +const ( + OrgField string = "org_id" +) + +type dialect struct{} + +func (dialect *dialect) GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) { + var columnType string + + err := bun. + NewSelect(). + ColumnExpr("type"). + TableExpr("pragma_table_info(?)", table). + Where("name = ?", column). + Scan(ctx, &columnType) + if err != nil { + return "", err + } + + return columnType, nil } -func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { +func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { columnType, err := dialect.GetColumnType(ctx, bun, table, column) if err != nil { return err @@ -73,7 +93,7 @@ func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, return nil } -func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { +func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { columnExists, err := dialect.ColumnExists(ctx, bun, table, column) if err != nil { return err @@ -118,22 +138,6 @@ func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, ta return nil } -func (dialect *dialect) GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) { - var columnType string - - err := bun. - NewSelect(). - ColumnExpr("type"). - TableExpr("pragma_table_info(?)", table). - Where("name = ?", column). - Scan(ctx, &columnType) - if err != nil { - return "", err - } - - return columnType, nil -} - func (dialect *dialect) ColumnExists(ctx context.Context, bun bun.IDB, table string, column string) (bool, error) { var count int err := bun.NewSelect(). @@ -217,7 +221,6 @@ func (dialect *dialect) DropColumn(ctx context.Context, bun bun.IDB, table strin } func (dialect *dialect) TableExists(ctx context.Context, bun bun.IDB, table interface{}) (bool, error) { - count := 0 err := bun. NewSelect(). @@ -423,3 +426,63 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel return nil } + +func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bunIDB bun.IDB, model interface{}, column string) error { + existingTable := bunIDB.Dialect().Tables().Get(reflect.TypeOf(model)) + columnExists, err := dialect.ColumnExists(ctx, bunIDB, existingTable.Name, column) + if err != nil { + return err + } + + if !columnExists { + return nil + } + + newTableName := existingTable.Name + "_tmp" + + // Create the newTmpTable query + createTableQuery := bunIDB.NewCreateTable().Model(model).ModelTableExpr(newTableName) + + var columnNames []string + + for _, field := range existingTable.Fields { + if field.Name != column { + columnNames = append(columnNames, string(field.SQLName)) + } + + if field.Name == OrgField { + createTableQuery = createTableQuery.ForeignKey(OrgReference) + } + } + + // Disable foreign keys temporarily + if _, err := bunIDB.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil { + return err + } + + if _, err = createTableQuery.Exec(ctx); err != nil { + return err + } + + // Copy data from old table to new table + if _, err := bunIDB.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", newTableName, strings.Join(columnNames, ", "), existingTable.Name)); err != nil { + return err + } + + _, err = bunIDB.NewDropTable().Table(existingTable.Name).Exec(ctx) + if err != nil { + return err + } + + _, err = bunIDB.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME TO %s", newTableName, existingTable.Name)) + if err != nil { + return err + } + + // Re-enable foreign keys + if _, err := bunIDB.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { + return err + } + + return nil +} diff --git a/pkg/sqlstore/sqlstore.go b/pkg/sqlstore/sqlstore.go index 8f4fead315..d3b6cb3b30 100644 --- a/pkg/sqlstore/sqlstore.go +++ b/pkg/sqlstore/sqlstore.go @@ -37,15 +37,42 @@ type SQLStoreHook interface { } type SQLDialect interface { - MigrateIntToTimestamp(context.Context, bun.IDB, string, string) error - MigrateIntToBoolean(context.Context, bun.IDB, string, string) error - AddNotNullDefaultToColumn(context.Context, bun.IDB, string, string, string, string) error + // Returns the type of the column for the given table and column. GetColumnType(context.Context, bun.IDB, string, string) (string, error) + + // Migrates an integer column to a timestamp column for the given table and column. + IntToTimestamp(context.Context, bun.IDB, string, string) error + + // Migrates an integer column to a boolean column for the given table and column. + IntToBoolean(context.Context, bun.IDB, string, string) error + + // Adds a not null default to the given column for the given table, column, columnType and defaultValue. + AddNotNullDefaultToColumn(context.Context, bun.IDB, string, string, string, string) error + + // Checks if a column exists in a table for the given table and column. ColumnExists(context.Context, bun.IDB, string, string) (bool, error) + + // Adds a column to a table for the given table, column and columnType. AddColumn(context.Context, bun.IDB, string, string, string) error - RenameColumn(context.Context, bun.IDB, string, string, string) (bool, error) + + // Drops a column from a table for the given table and column. DropColumn(context.Context, bun.IDB, string, string) error + + // Renames a column in a table for the given table, old column name and new column name. + RenameColumn(context.Context, bun.IDB, string, string, string) (bool, error) + + // Renames a table and modifies the given model for the given table, old model, new model, references and callback. The old model + // and new model must inherit bun.BaseModel. RenameTableAndModifyModel(context.Context, bun.IDB, interface{}, interface{}, []string, func(context.Context) error) error + + // Updates the primary key for the given table, old model, new model, reference and callback. The old model and new model + // must inherit bun.BaseModel. UpdatePrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error + + // Adds a primary key to the given table, old model, new model, reference and callback. The old model and new model + // must inherit bun.BaseModel. AddPrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error + + // Drops the column and the associated foreign key constraint for the given table and column. + DropColumnWithForeignKeyConstraint(context.Context, bun.IDB, interface{}, string) error } diff --git a/pkg/sqlstore/sqlstoretest/dialect.go b/pkg/sqlstore/sqlstoretest/dialect.go index d14926c3db..897f946ed8 100644 --- a/pkg/sqlstore/sqlstoretest/dialect.go +++ b/pkg/sqlstore/sqlstoretest/dialect.go @@ -9,11 +9,11 @@ import ( type dialect struct { } -func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { +func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { return nil } -func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { +func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { return nil } @@ -56,3 +56,7 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel func (dialect *dialect) IndexExists(ctx context.Context, bun bun.IDB, table string, index string) (bool, error) { return false, nil } + +func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bun bun.IDB, model interface{}, column string) error { + return nil +} diff --git a/pkg/types/authtypes/claims.go b/pkg/types/authtypes/claims.go new file mode 100644 index 0000000000..3ce3d79b43 --- /dev/null +++ b/pkg/types/authtypes/claims.go @@ -0,0 +1,83 @@ +package authtypes + +import ( + "log/slog" + "slices" + + "github.com/SigNoz/signoz/pkg/errors" + "github.com/golang-jwt/jwt/v5" +) + +var _ jwt.ClaimsValidator = (*Claims)(nil) + +type Claims struct { + jwt.RegisteredClaims + UserID string `json:"id"` + Email string `json:"email"` + Role Role `json:"role"` + OrgID string `json:"orgId"` +} + +func (c *Claims) Validate() error { + if c.UserID == "" { + return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "id is required") + } + + // The problem is that when the "role" field is missing entirely from the JSON (as opposed to being present but empty), the UnmarshalJSON method for Role isn't called at all. + // The JSON decoder just sets the Role field to its zero value (""). + if c.Role == "" { + return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "role is required") + } + + if c.OrgID == "" { + return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "orgId is required") + } + + return nil +} + +func (c *Claims) LogValue() slog.Value { + return slog.GroupValue( + slog.String("id", c.UserID), + slog.String("email", c.Email), + slog.String("role", c.Role.String()), + slog.String("orgId", c.OrgID), + slog.Time("exp", c.ExpiresAt.Time), + ) +} + +func (c *Claims) IsViewer() error { + if slices.Contains([]Role{RoleViewer, RoleEditor, RoleAdmin}, c.Role) { + return nil + } + + return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only viewers/editors/admins can access this resource") +} + +func (c *Claims) IsEditor() error { + if slices.Contains([]Role{RoleEditor, RoleAdmin}, c.Role) { + return nil + } + + return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only editors/admins can access this resource") +} + +func (c *Claims) IsAdmin() error { + if c.Role == RoleAdmin { + return nil + } + + return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only admins can access this resource") +} + +func (c *Claims) IsSelfAccess(id string) error { + if c.UserID == id { + return nil + } + + if c.Role == RoleAdmin { + return nil + } + + return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only the user/admin can access their own resource") +} diff --git a/pkg/types/authtypes/jwt.go b/pkg/types/authtypes/jwt.go index 7754eaf185..5b31584a89 100644 --- a/pkg/types/authtypes/jwt.go +++ b/pkg/types/authtypes/jwt.go @@ -2,24 +2,15 @@ package authtypes import ( "context" - "errors" - "fmt" "strings" "time" + "github.com/SigNoz/signoz/pkg/errors" "github.com/golang-jwt/jwt/v5" ) type jwtClaimsKey struct{} -type Claims struct { - jwt.RegisteredClaims - UserID string `json:"id"` - GroupID string `json:"gid"` - Email string `json:"email"` - OrgID string `json:"orgId"` -} - type JWT struct { JwtSecret string JwtExpiry time.Duration @@ -34,16 +25,6 @@ func NewJWT(jwtSecret string, jwtExpiry time.Duration, jwtRefresh time.Duration) } } -func parseBearerAuth(auth string) (string, bool) { - const prefix = "Bearer " - // Case insensitive prefix match - if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { - return "", false - } - - return auth[len(prefix):], true -} - func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context.Context, error) { var value string for _, v := range values { @@ -54,7 +35,7 @@ func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context } if value == "" { - return ctx, errors.New("missing Authorization header") + return ctx, errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "missing authorization header") } // parse from @@ -73,24 +54,18 @@ func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context } func (j *JWT) Claims(jwtStr string) (Claims, error) { - token, err := jwt.ParseWithClaims(jwtStr, &Claims{}, func(token *jwt.Token) (interface{}, error) { + claims := Claims{} + _, err := jwt.ParseWithClaims(jwtStr, &claims, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unknown signing algo: %v", token.Header["alg"]) + return nil, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unrecognized signing algorithm: %s", token.Method.Alg()) } return []byte(j.JwtSecret), nil }) - if err != nil { - return Claims{}, fmt.Errorf("failed to parse jwt token: %w", err) + return Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to parse jwt token") } - // Type assertion to retrieve claims from the token - userClaims, ok := token.Claims.(*Claims) - if !ok { - return Claims{}, errors.New("failed to retrieve claims from token") - } - - return *userClaims, nil + return claims, nil } // NewContextWithClaims attaches individual claims to the context. @@ -106,36 +81,62 @@ func (j *JWT) signToken(claims Claims) (string, error) { } // AccessToken creates an access token with the provided claims -func (j *JWT) AccessToken(orgId, userId, groupId, email string) (string, error) { +func (j *JWT) AccessToken(orgId, userId, email string, role Role) (string, Claims, error) { claims := Claims{ - UserID: userId, - GroupID: groupId, - Email: email, - OrgID: orgId, + UserID: userId, + Role: role, + Email: email, + OrgID: orgId, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtExpiry)), IssuedAt: jwt.NewNumericDate(time.Now()), }, } - return j.signToken(claims) + + token, err := j.signToken(claims) + if err != nil { + return "", Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to sign token") + } + + return token, claims, nil } // RefreshToken creates a refresh token with the provided claims -func (j *JWT) RefreshToken(orgId, userId, groupId, email string) (string, error) { +func (j *JWT) RefreshToken(orgId, userId, email string, role Role) (string, Claims, error) { claims := Claims{ - UserID: userId, - GroupID: groupId, - Email: email, - OrgID: orgId, + UserID: userId, + Role: role, + Email: email, + OrgID: orgId, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtRefresh)), IssuedAt: jwt.NewNumericDate(time.Now()), }, } - return j.signToken(claims) + + token, err := j.signToken(claims) + if err != nil { + return "", Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to sign token") + } + + return token, claims, nil } -func ClaimsFromContext(ctx context.Context) (Claims, bool) { +func ClaimsFromContext(ctx context.Context) (Claims, error) { claims, ok := ctx.Value(jwtClaimsKey{}).(Claims) - return claims, ok + if !ok { + return Claims{}, errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated") + } + + return claims, nil +} + +func parseBearerAuth(auth string) (string, bool) { + const prefix = "Bearer " + // Case insensitive prefix match + if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) { + return "", false + } + + return auth[len(prefix):], true } diff --git a/pkg/types/authtypes/jwt_test.go b/pkg/types/authtypes/jwt_test.go index bfd9749f54..b98aa3144d 100644 --- a/pkg/types/authtypes/jwt_test.go +++ b/pkg/types/authtypes/jwt_test.go @@ -4,35 +4,36 @@ import ( "testing" "time" + "github.com/SigNoz/signoz/pkg/errors" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" ) -func TestGetAccessJwt(t *testing.T) { +func TestJwtAccessToken(t *testing.T) { jwtService := NewJWT("secret", time.Minute, time.Hour) - token, err := jwtService.AccessToken("orgId", "userId", "groupId", "email@example.com") + token, _, err := jwtService.AccessToken("orgId", "userId", "email@example.com", RoleAdmin) assert.NoError(t, err) assert.NotEmpty(t, token) } -func TestGetRefreshJwt(t *testing.T) { +func TestJwtRefreshToken(t *testing.T) { jwtService := NewJWT("secret", time.Minute, time.Hour) - token, err := jwtService.RefreshToken("orgId", "userId", "groupId", "email@example.com") + token, _, err := jwtService.RefreshToken("orgId", "userId", "email@example.com", RoleAdmin) assert.NoError(t, err) assert.NotEmpty(t, token) } -func TestGetJwtClaims(t *testing.T) { +func TestJwtClaims(t *testing.T) { jwtService := NewJWT("secret", time.Minute, time.Hour) // Create a valid token claims := Claims{ - UserID: "userId", - GroupID: "groupId", - Email: "email@example.com", - OrgID: "orgId", + UserID: "userId", + Role: RoleAdmin, + Email: "email@example.com", + OrgID: "orgId", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), IssuedAt: jwt.NewNumericDate(time.Now()), @@ -45,29 +46,28 @@ func TestGetJwtClaims(t *testing.T) { retrievedClaims, err := jwtService.Claims(tokenString) assert.NoError(t, err) assert.Equal(t, claims.UserID, retrievedClaims.UserID) - assert.Equal(t, claims.GroupID, retrievedClaims.GroupID) + assert.Equal(t, claims.Role, retrievedClaims.Role) assert.Equal(t, claims.Email, retrievedClaims.Email) assert.Equal(t, claims.OrgID, retrievedClaims.OrgID) } -func TestGetJwtClaimsInvalidToken(t *testing.T) { +func TestJwtClaimsInvalidToken(t *testing.T) { jwtService := NewJWT("secret", time.Minute, time.Hour) - // Test retrieving claims from an invalid token _, err := jwtService.Claims("invalid.token.string") assert.Error(t, err) assert.Contains(t, err.Error(), "token is malformed") } -func TestGetJwtClaimsExpiredToken(t *testing.T) { +func TestJwtClaimsExpiredToken(t *testing.T) { jwtService := NewJWT("secret", time.Minute, time.Hour) // Create an expired token claims := Claims{ - UserID: "userId", - GroupID: "groupId", - Email: "email@example.com", - OrgID: "orgId", + UserID: "userId", + Role: RoleAdmin, + Email: "email@example.com", + OrgID: "orgId", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), IssuedAt: jwt.NewNumericDate(time.Now()), @@ -81,15 +81,15 @@ func TestGetJwtClaimsExpiredToken(t *testing.T) { assert.Contains(t, err.Error(), "token is expired") } -func TestGetJwtClaimsInvalidSignature(t *testing.T) { +func TestJwtClaimsInvalidSignature(t *testing.T) { jwtService := NewJWT("secret", time.Minute, time.Hour) // Create a valid token claims := Claims{ - UserID: "userId", - GroupID: "groupId", - Email: "email@example.com", - OrgID: "orgId", + UserID: "userId", + Role: RoleAdmin, + Email: "email@example.com", + OrgID: "orgId", RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), }, @@ -106,6 +106,86 @@ func TestGetJwtClaimsInvalidSignature(t *testing.T) { assert.Contains(t, err.Error(), "signature is invalid") } +func TestJwtClaimsWithInvalidRole(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + + claims := Claims{ + UserID: "userId", + Role: "INVALID_ROLE", + Email: "email@example.com", + OrgID: "orgId", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + }, + } + validToken, err := jwtService.signToken(claims) + assert.NoError(t, err) + + _, err = jwtService.Claims(validToken) + assert.Error(t, err) + assert.True(t, errors.Ast(err, errors.TypeUnauthenticated)) +} + +func TestJwtClaimsMissingUserID(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + + claims := Claims{ + UserID: "", + Role: RoleAdmin, + Email: "email@example.com", + OrgID: "orgId", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + }, + } + validToken, err := jwtService.signToken(claims) + assert.NoError(t, err) + + _, err = jwtService.Claims(validToken) + assert.Error(t, err) + assert.True(t, errors.Ast(err, errors.TypeUnauthenticated)) +} + +func TestJwtClaimsMissingRole(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + + claims := Claims{ + UserID: "userId", + Role: "", + Email: "email@example.com", + OrgID: "orgId", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + }, + } + validToken, err := jwtService.signToken(claims) + assert.NoError(t, err) + + _, err = jwtService.Claims(validToken) + assert.Error(t, err) + assert.True(t, errors.Ast(err, errors.TypeUnauthenticated)) +} + +func TestJwtClaimsMissingOrgID(t *testing.T) { + jwtService := NewJWT("secret", time.Minute, time.Hour) + + claims := Claims{ + UserID: "userId", + Role: RoleAdmin, + Email: "email@example.com", + OrgID: "", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + }, + } + validToken, err := jwtService.signToken(claims) + assert.NoError(t, err) + + _, err = jwtService.Claims(validToken) + assert.Error(t, err) + assert.True(t, errors.Ast(err, errors.TypeUnauthenticated)) +} + func TestParseBearerAuth(t *testing.T) { tests := []struct { auth string diff --git a/pkg/types/authtypes/role.go b/pkg/types/authtypes/role.go new file mode 100644 index 0000000000..16ac7fc6cb --- /dev/null +++ b/pkg/types/authtypes/role.go @@ -0,0 +1,52 @@ +package authtypes + +import ( + "encoding/json" + + "github.com/SigNoz/signoz/pkg/errors" +) + +// Do not take inspiration from this. This is a hack to avoid using valuer.String and use upper case strings. +type Role string + +const ( + RoleAdmin Role = "ADMIN" + RoleEditor Role = "EDITOR" + RoleViewer Role = "VIEWER" +) + +func NewRole(role string) (Role, error) { + switch role { + case "ADMIN": + return RoleAdmin, nil + case "EDITOR": + return RoleEditor, nil + case "VIEWER": + return RoleViewer, nil + } + + return "", errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid role: %s", role) +} + +func (r Role) String() string { + return string(r) +} + +func (r *Role) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + role, err := NewRole(s) + if err != nil { + return err + } + + *r = role + return nil +} + +func (r Role) MarshalJSON() ([]byte, error) { + return json.Marshal(r.String()) +} diff --git a/pkg/types/user.go b/pkg/types/user.go index ba9e1eb965..21932af148 100644 --- a/pkg/types/user.go +++ b/pkg/types/user.go @@ -16,18 +16,8 @@ type Invite struct { Role string `bun:"role,type:text,notnull" json:"role"` } -type Group struct { - bun.BaseModel `bun:"table:groups"` - - TimeAuditable - OrgID string `bun:"org_id,type:text"` - ID string `bun:"id,pk,type:text" json:"id"` - Name string `bun:"name,type:text,notnull,unique" json:"name"` -} - type GettableUser struct { User - Role string `json:"role"` Organization string `json:"organization"` } @@ -40,7 +30,7 @@ type User struct { Email string `bun:"email,type:text,notnull,unique" json:"email"` Password string `bun:"password,type:text,notnull" json:"-"` ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"` - GroupID string `bun:"group_id,type:text,notnull" json:"groupId"` + Role string `bun:"role,type:text,notnull" json:"role"` OrgID string `bun:"org_id,type:text,notnull" json:"orgId"` } diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index bc891ff5e2..011ee2b55e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -16,6 +16,12 @@ pytest_plugins = [ def pytest_addoption(parser: pytest.Parser): + parser.addoption( + "--dev", + action="store", + default=False, + help="Run in dev mode. In this mode, the containers are not torn down after test run and are reused in subsequent test runs.", + ) parser.addoption( "--sqlstore-provider", action="store", diff --git a/tests/integration/fixtures/__init__.py b/tests/integration/fixtures/__init__.py index e69de29bb2..3d5ee4ee94 100644 --- a/tests/integration/fixtures/__init__.py +++ b/tests/integration/fixtures/__init__.py @@ -0,0 +1,3 @@ +from testcontainers.core.config import testcontainers_config as config + +config.ryuk_disabled = True diff --git a/tests/integration/fixtures/clickhouse.py b/tests/integration/fixtures/clickhouse.py index 573b5e11cd..e877b47325 100644 --- a/tests/integration/fixtures/clickhouse.py +++ b/tests/integration/fixtures/clickhouse.py @@ -1,3 +1,4 @@ +import dataclasses import os from typing import Any, Generator @@ -15,10 +16,46 @@ def clickhouse( network: Network, zookeeper: types.TestContainerDocker, request: pytest.FixtureRequest, + pytestconfig: pytest.Config, ) -> types.TestContainerClickhouse: """ Package-scoped fixture for Clickhouse TestContainer. """ + dev = request.config.getoption("--dev") + if dev: + container = pytestconfig.cache.get("clickhouse.container", None) + env = pytestconfig.cache.get("clickhouse.env", None) + + if container and env: + assert isinstance(container, dict) + assert isinstance(env, dict) + + test_container = types.TestContainerDocker( + host_config=types.TestContainerUrlConfig( + container["host_config"]["scheme"], + container["host_config"]["address"], + container["host_config"]["port"], + ), + container_config=types.TestContainerUrlConfig( + container["container_config"]["scheme"], + container["container_config"]["address"], + container["container_config"]["port"], + ), + ) + + connection = clickhouse_driver.connect( + user=env["SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_USERNAME"], + password=env["SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_PASSWORD"], + host=test_container.host_config.address, + port=test_container.host_config.port, + ) + + return types.TestContainerClickhouse( + container=test_container, + conn=connection, + env=env, + ) + version = request.config.getoption("--clickhouse-version") container = ClickHouseContainer( @@ -91,21 +128,37 @@ def clickhouse( ) def stop(): + if dev: + return + connection.close() container.stop(delete_volume=True) request.addfinalizer(stop) - return types.TestContainerClickhouse( - container=container, - host_config=types.TestContainerUrlConfig( - "tcp", container.get_container_host_ip(), container.get_exposed_port(9000) - ), - container_config=types.TestContainerUrlConfig( - "tcp", container.get_wrapped_container().name, 9000 + cached_clickhouse = types.TestContainerClickhouse( + container=types.TestContainerDocker( + host_config=types.TestContainerUrlConfig( + "tcp", + container.get_container_host_ip(), + container.get_exposed_port(9000), + ), + container_config=types.TestContainerUrlConfig( + "tcp", container.get_wrapped_container().name, 9000 + ), ), conn=connection, env={ - "SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_DSN": f"tcp://{container.username}:{container.password}@{container.get_wrapped_container().name}:{9000}" # pylint: disable=line-too-long + "SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_DSN": f"tcp://{container.username}:{container.password}@{container.get_wrapped_container().name}:{9000}", # pylint: disable=line-too-long + "SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_USERNAME": container.username, + "SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_PASSWORD": container.password, }, ) + + if dev: + pytestconfig.cache.set( + "clickhouse.container", dataclasses.asdict(cached_clickhouse.container) + ) + pytestconfig.cache.set("clickhouse.env", cached_clickhouse.env) + + return cached_clickhouse diff --git a/tests/integration/fixtures/http.py b/tests/integration/fixtures/http.py index 0de978f9c0..fd8f798df4 100644 --- a/tests/integration/fixtures/http.py +++ b/tests/integration/fixtures/http.py @@ -1,3 +1,4 @@ +import dataclasses from typing import List import pytest @@ -14,23 +15,44 @@ from fixtures import types @pytest.fixture(name="zeus", scope="package") def zeus( - network: Network, request: pytest.FixtureRequest -) -> types.TestContainerWiremock: + network: Network, + request: pytest.FixtureRequest, + pytestconfig: pytest.Config, +) -> types.TestContainerDocker: """ Package-scoped fixture for running zeus """ + dev = request.config.getoption("--dev") + if dev: + cached_zeus = pytestconfig.cache.get("zeus", None) + if cached_zeus: + return types.TestContainerDocker( + host_config=types.TestContainerUrlConfig( + cached_zeus["host_config"]["scheme"], + cached_zeus["host_config"]["address"], + cached_zeus["host_config"]["port"], + ), + container_config=types.TestContainerUrlConfig( + cached_zeus["container_config"]["scheme"], + cached_zeus["container_config"]["address"], + cached_zeus["container_config"]["port"], + ), + ) + container = WireMockContainer(image="wiremock/wiremock:2.35.1-1", secure=False) container.with_network(network) container.start() def stop(): + if dev: + return + container.stop(delete_volume=True) request.addfinalizer(stop) - return types.TestContainerWiremock( - container=container, + cached_zeus = types.TestContainerDocker( host_config=types.TestContainerUrlConfig( "http", container.get_container_host_ip(), container.get_exposed_port(8080) ), @@ -39,11 +61,16 @@ def zeus( ), ) + if dev: + pytestconfig.cache.set("zeus", dataclasses.asdict(cached_zeus)) + + return cached_zeus + @pytest.fixture(name="make_http_mocks", scope="function") def make_http_mocks(): - def _make_http_mocks(container: WireMockContainer, mappings: List[Mapping]): - Config.base_url = container.get_url("__admin") + def _make_http_mocks(container: types.TestContainerDocker, mappings: List[Mapping]): + Config.base_url = container.host_config.get("/__admin") for mapping in mappings: Mappings.create_mapping(mapping=mapping) diff --git a/tests/integration/fixtures/logger.py b/tests/integration/fixtures/logger.py new file mode 100644 index 0000000000..a8f27151a3 --- /dev/null +++ b/tests/integration/fixtures/logger.py @@ -0,0 +1,10 @@ +import logging + + +def setup_logger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + logger.addHandler(handler) + return logger diff --git a/tests/integration/fixtures/migrator.py b/tests/integration/fixtures/migrator.py index d389cbfd99..66ef6e492c 100644 --- a/tests/integration/fixtures/migrator.py +++ b/tests/integration/fixtures/migrator.py @@ -10,10 +10,17 @@ def migrator( network: Network, clickhouse: types.TestContainerClickhouse, request: pytest.FixtureRequest, + pytestconfig: pytest.Config, ) -> None: """ Package-scoped fixture for running schema migrations. """ + dev = request.config.getoption("--dev") + if dev: + cached_migrator = pytestconfig.cache.get("migrator", None) + if cached_migrator is not None and cached_migrator is True: + return None + version = request.config.getoption("--schema-migrator-version") client = docker.from_env() @@ -53,3 +60,6 @@ def migrator( raise RuntimeError("failed to run migrations on clickhouse") container.remove() + + if dev: + pytestconfig.cache.set("migrator", True) diff --git a/tests/integration/fixtures/network.py b/tests/integration/fixtures/network.py index b0bcc3fc1f..d0bffb8714 100644 --- a/tests/integration/fixtures/network.py +++ b/tests/integration/fixtures/network.py @@ -1,18 +1,45 @@ import pytest from testcontainers.core.container import Network +from fixtures.logger import setup_logger + +logger = setup_logger(__name__) + @pytest.fixture(name="network", scope="package") -def network(request: pytest.FixtureRequest) -> Network: +def network(request: pytest.FixtureRequest, pytestconfig: pytest.Config) -> Network: """ Package-Scoped fixture for creating a network """ nw = Network() + + dev = request.config.getoption("--dev") + if dev: + cached_network = pytestconfig.cache.get("network", None) + if cached_network: + logger.info("Using cached Network(%s)", cached_network) + nw.id = cached_network["id"] + nw.name = cached_network["name"] + return nw + nw.create() def stop(): - nw.remove() + dev = request.config.getoption("--dev") + if dev: + logger.info( + "Skipping removal of Network(%s)", {"name": nw.name, "id": nw.id} + ) + else: + logger.info("Removing Network(%s)", {"name": nw.name, "id": nw.id}) + nw.remove() request.addfinalizer(stop) - return nw + cached_network = nw + if dev: + pytestconfig.cache.set( + "network", {"name": cached_network.name, "id": cached_network.id} + ) + + return cached_network diff --git a/tests/integration/fixtures/postgres.py b/tests/integration/fixtures/postgres.py index 7643e03307..94b805072a 100644 --- a/tests/integration/fixtures/postgres.py +++ b/tests/integration/fixtures/postgres.py @@ -1,3 +1,5 @@ +import dataclasses + import psycopg2 import pytest from testcontainers.core.container import Network @@ -8,11 +10,45 @@ from fixtures import types @pytest.fixture(name="postgres", scope="package") def postgres( - network: Network, request: pytest.FixtureRequest + network: Network, request: pytest.FixtureRequest, pytestconfig: pytest.Config ) -> types.TestContainerSQL: """ Package-scoped fixture for PostgreSQL TestContainer. """ + dev = request.config.getoption("--dev") + if dev: + container = pytestconfig.cache.get("postgres.container", None) + env = pytestconfig.cache.get("postgres.env", None) + + if container and env: + assert isinstance(container, dict) + assert isinstance(env, dict) + + test_container = types.TestContainerDocker( + host_config=types.TestContainerUrlConfig( + container["host_config"]["scheme"], + container["host_config"]["address"], + container["host_config"]["port"], + ), + container_config=types.TestContainerUrlConfig( + container["container_config"]["scheme"], + container["container_config"]["address"], + container["container_config"]["port"], + ), + ) + + return types.TestContainerSQL( + container=test_container, + conn=psycopg2.connect( + dbname=env["SIGNOZ_SQLSTORE_POSTGRES_DBNAME"], + user=env["SIGNOZ_SQLSTORE_POSTGRES_USER"], + password=env["SIGNOZ_SQLSTORE_POSTGRES_PASSWORD"], + host=test_container.host_config.address, + port=test_container.host_config.port, + ), + env=env, + ) + version = request.config.getoption("--postgres-version") container = PostgresContainer( @@ -35,24 +71,39 @@ def postgres( ) def stop(): + if dev: + return + connection.close() container.stop(delete_volume=True) request.addfinalizer(stop) - return types.TestContainerSQL( - container=container, - host_config=types.TestContainerUrlConfig( - "postgresql", - container.get_container_host_ip(), - container.get_exposed_port(5432), - ), - container_config=types.TestContainerUrlConfig( - "postgresql", container.get_wrapped_container().name, 5432 + cached_postgres = types.TestContainerSQL( + container=types.TestContainerDocker( + host_config=types.TestContainerUrlConfig( + "postgresql", + container.get_container_host_ip(), + container.get_exposed_port(5432), + ), + container_config=types.TestContainerUrlConfig( + "postgresql", container.get_wrapped_container().name, 5432 + ), ), conn=connection, env={ "SIGNOZ_SQLSTORE_PROVIDER": "postgres", "SIGNOZ_SQLSTORE_POSTGRES_DSN": f"postgresql://{container.username}:{container.password}@{container.get_wrapped_container().name}:{5432}/{container.dbname}", # pylint: disable=line-too-long + "SIGNOZ_SQLSTORE_POSTGRES_DBNAME": container.dbname, + "SIGNOZ_SQLSTORE_POSTGRES_USER": container.username, + "SIGNOZ_SQLSTORE_POSTGRES_PASSWORD": container.password, }, ) + + if dev: + pytestconfig.cache.set( + "postgres.container", dataclasses.asdict(cached_postgres.container) + ) + pytestconfig.cache.set("postgres.env", cached_postgres.env) + + return cached_postgres diff --git a/tests/integration/fixtures/signoz.py b/tests/integration/fixtures/signoz.py index d89fd6707b..e5b02e8581 100644 --- a/tests/integration/fixtures/signoz.py +++ b/tests/integration/fixtures/signoz.py @@ -1,3 +1,4 @@ +import dataclasses import platform import time from http import HTTPStatus @@ -8,20 +9,44 @@ from testcontainers.core.container import DockerContainer, Network from testcontainers.core.image import DockerImage from fixtures import types +from fixtures.logger import setup_logger + +logger = setup_logger(__name__) @pytest.fixture(name="signoz", scope="package") def signoz( network: Network, - zeus: types.TestContainerWiremock, + zeus: types.TestContainerDocker, sqlstore: types.TestContainerSQL, clickhouse: types.TestContainerClickhouse, request: pytest.FixtureRequest, + pytestconfig: pytest.Config, ) -> types.SigNoz: """ Package-scoped fixture for setting up SigNoz. """ + dev = request.config.getoption("--dev") + if dev: + cached_signoz = pytestconfig.cache.get("signoz.container", None) + if cached_signoz: + self = types.TestContainerDocker( + host_config=types.TestContainerUrlConfig( + cached_signoz["host_config"]["scheme"], + cached_signoz["host_config"]["address"], + cached_signoz["host_config"]["port"], + ), + container_config=types.TestContainerUrlConfig( + cached_signoz["container_config"]["scheme"], + cached_signoz["container_config"]["address"], + cached_signoz["container_config"]["port"], + ), + ) + return types.SigNoz( + self=self, sqlstore=sqlstore, telemetrystore=clickhouse, zeus=zeus + ) + # Run the migrations for clickhouse request.getfixturevalue("migrator") @@ -70,7 +95,7 @@ def signoz( container.start() def ready(container: DockerContainer) -> None: - for attempt in range(30): + for attempt in range(5): try: response = requests.get( f"http://{container.get_container_host_ip()}:{container.get_exposed_port(8080)}/api/v1/health", # pylint: disable=line-too-long @@ -78,22 +103,31 @@ def signoz( ) return response.status_code == HTTPStatus.OK except Exception: # pylint: disable=broad-exception-caught - print(f"attempt {attempt} at health check failed") + logger.info( + "Attempt %s at readiness check for SigNoz container %s failed, going to retry ...", # pylint: disable=line-too-long + attempt + 1, + container, + ) time.sleep(2) raise TimeoutError("timeout exceeded while waiting") - ready(container=container) + try: + ready(container=container) + except Exception as e: # pylint: disable=broad-exception-caught + raise e def stop(): - logs = container.get_wrapped_container().logs(tail=100) - print(logs.decode(encoding="utf-8")) - container.stop(delete_volume=True) + if dev: + logger.info("Skipping removal of SigNoz container %s ...", container) + return + else: + logger.info("Removing SigNoz container %s ...", container) + container.stop(delete_volume=True) request.addfinalizer(stop) - return types.SigNoz( + cached_signoz = types.SigNoz( self=types.TestContainerDocker( - container=container, host_config=types.TestContainerUrlConfig( "http", container.get_container_host_ip(), @@ -109,3 +143,10 @@ def signoz( telemetrystore=clickhouse, zeus=zeus, ) + + if dev: + pytestconfig.cache.set( + "signoz.container", dataclasses.asdict(cached_signoz.self) + ) + + return cached_signoz diff --git a/tests/integration/fixtures/sqlite.py b/tests/integration/fixtures/sqlite.py index 2c40dbba4e..2d2be15408 100644 --- a/tests/integration/fixtures/sqlite.py +++ b/tests/integration/fixtures/sqlite.py @@ -11,27 +11,59 @@ ConnectionTuple = namedtuple("ConnectionTuple", "connection config") @pytest.fixture(name="sqlite", scope="package") def sqlite( - tmpfs: Generator[types.LegacyPath, Any, None], request: pytest.FixtureRequest + tmpfs: Generator[types.LegacyPath, Any, None], + request: pytest.FixtureRequest, + pytestconfig: pytest.Config, ) -> types.TestContainerSQL: """ Package-scoped fixture for SQLite. """ + + dev = request.config.getoption("--dev") + if dev: + container = pytestconfig.cache.get("sqlite.container", None) + env = pytestconfig.cache.get("sqlite.env", None) + + if container and env: + assert isinstance(container, dict) + assert isinstance(env, dict) + + return types.TestContainerSQL( + container=types.TestContainerDocker( + host_config=None, + container_config=None, + ), + conn=sqlite3.connect( + env["SIGNOZ_SQLSTORE_SQLITE_PATH"], check_same_thread=False + ), + env=env, + ) + tmpdir = tmpfs("sqlite") path = tmpdir / "signoz.db" connection = sqlite3.connect(path, check_same_thread=False) def stop(): + if dev: + return + connection.close() request.addfinalizer(stop) - return types.TestContainerSQL( - None, - host_config=None, - container_config=None, + cached_sqlite = types.TestContainerSQL( + container=types.TestContainerDocker( + host_config=None, + container_config=None, + ), conn=connection, env={ "SIGNOZ_SQLSTORE_PROVIDER": "sqlite", "SIGNOZ_SQLSTORE_SQLITE_PATH": str(path), }, ) + + if dev: + pytestconfig.cache.set("sqlite.env", cached_sqlite.env) + + return cached_sqlite diff --git a/tests/integration/fixtures/types.py b/tests/integration/fixtures/types.py index 38888bd6bb..a5b9ab128e 100644 --- a/tests/integration/fixtures/types.py +++ b/tests/integration/fixtures/types.py @@ -4,8 +4,6 @@ from urllib.parse import urljoin import py from clickhouse_driver.dbapi import Connection -from testcontainers.core.container import DockerContainer -from wiremock.testing.testcontainer import WireMockContainer LegacyPath = py.path.local @@ -27,29 +25,22 @@ class TestContainerUrlConfig: @dataclass class TestContainerDocker: __test__ = False - container: DockerContainer host_config: TestContainerUrlConfig container_config: TestContainerUrlConfig @dataclass -class TestContainerWiremock(TestContainerDocker): +class TestContainerSQL: __test__ = False - container: WireMockContainer - - -@dataclass -class TestContainerSQL(TestContainerDocker): - __test__ = False - container: DockerContainer + container: TestContainerDocker conn: any env: Dict[str, str] @dataclass -class TestContainerClickhouse(TestContainerDocker): +class TestContainerClickhouse: __test__ = False - container: DockerContainer + container: TestContainerDocker conn: Connection env: Dict[str, str] @@ -60,4 +51,4 @@ class SigNoz: self: TestContainerDocker sqlstore: TestContainerSQL telemetrystore: TestContainerClickhouse - zeus: TestContainerWiremock + zeus: TestContainerDocker diff --git a/tests/integration/fixtures/zookeeper.py b/tests/integration/fixtures/zookeeper.py index f1b3b34ab0..7252582b8c 100644 --- a/tests/integration/fixtures/zookeeper.py +++ b/tests/integration/fixtures/zookeeper.py @@ -1,3 +1,5 @@ +import dataclasses + import pytest from testcontainers.core.container import DockerContainer, Network @@ -6,11 +8,29 @@ from fixtures import types @pytest.fixture(name="zookeeper", scope="package") def zookeeper( - network: Network, request: pytest.FixtureRequest + network: Network, request: pytest.FixtureRequest, pytestconfig: pytest.Config ) -> types.TestContainerDocker: """ Package-scoped fixture for Zookeeper TestContainer. """ + + dev = request.config.getoption("--dev") + if dev: + cached_zookeeper = pytestconfig.cache.get("zookeeper", None) + if cached_zookeeper: + return types.TestContainerDocker( + host_config=types.TestContainerUrlConfig( + cached_zookeeper["host_config"]["scheme"], + cached_zookeeper["host_config"]["address"], + cached_zookeeper["host_config"]["port"], + ), + container_config=types.TestContainerUrlConfig( + cached_zookeeper["container_config"]["scheme"], + cached_zookeeper["container_config"]["address"], + cached_zookeeper["container_config"]["port"], + ), + ) + version = request.config.getoption("--zookeeper-version") container = DockerContainer(image=f"bitnami/zookeeper:{version}") @@ -21,12 +41,14 @@ def zookeeper( container.start() def stop(): + if dev: + return + container.stop(delete_volume=True) request.addfinalizer(stop) - return types.TestContainerDocker( - container=container, + cached_zookeeper = types.TestContainerDocker( host_config=types.TestContainerUrlConfig( "tcp", container.get_container_host_ip(), @@ -38,3 +60,8 @@ def zookeeper( 2181, ), ) + + if dev: + pytestconfig.cache.set("zookeeper", dataclasses.asdict(cached_zookeeper)) + + return cached_zookeeper diff --git a/tests/integration/pyproject.toml b/tests/integration/pyproject.toml index 067d712b9a..95703607cd 100644 --- a/tests/integration/pyproject.toml +++ b/tests/integration/pyproject.toml @@ -27,6 +27,10 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] python_files = "src/**/**.py" +log_cli = true +log_format = "%(asctime)s [%(levelname)s] (%(filename)s:%(lineno)s) %(message)s" +log_date_format = "%Y-%m-%d %H:%M:%S" +addopts = "-ra" [tool.pylint.main] ignore = [".venv"] diff --git a/tests/integration/src/bootstrap/b_register.py b/tests/integration/src/bootstrap/b_register.py index 3e26e05378..a2d3d91fa6 100644 --- a/tests/integration/src/bootstrap/b_register.py +++ b/tests/integration/src/bootstrap/b_register.py @@ -3,9 +3,12 @@ from http import HTTPStatus import requests from fixtures import types +from fixtures.logger import setup_logger + +logger = setup_logger(__name__) -def test_register(signoz: types.SigNoz) -> None: +def test_register(signoz: types.SigNoz, get_jwt_token) -> None: response = requests.get(signoz.self.host_config.get("/api/v1/version"), timeout=2) assert response.status_code == HTTPStatus.OK @@ -16,8 +19,8 @@ def test_register(signoz: types.SigNoz) -> None: json={ "name": "admin", "orgId": "", - "orgName": "", - "email": "admin@admin.com", + "orgName": "integration.test", + "email": "admin@integration.test", "password": "password", }, timeout=2, @@ -28,3 +31,175 @@ def test_register(signoz: types.SigNoz) -> None: assert response.status_code == HTTPStatus.OK assert response.json()["setupCompleted"] is True + + admin_token = get_jwt_token("admin@integration.test", "password") + + response = requests.get( + signoz.self.host_config.get("/api/v1/user"), + timeout=2, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == HTTPStatus.OK + + user_response = response.json() + found_user = next( + (user for user in user_response if user["email"] == "admin@integration.test"), + None, + ) + + assert found_user is not None + assert found_user["role"] == "ADMIN" + + response = requests.get( + signoz.self.host_config.get(f"/api/v1/rbac/role/{found_user["id"]}"), + timeout=2, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == HTTPStatus.OK + assert response.json()["group_name"] == "ADMIN" + + +def test_invite_and_register(signoz: types.SigNoz, get_jwt_token) -> None: + # Generate an invite token for the editor user + response = requests.post( + signoz.self.host_config.get("/api/v1/invite"), + json={"email": "editor@integration.test", "role": "EDITOR"}, + timeout=2, + headers={ + "Authorization": f"Bearer {get_jwt_token("admin@integration.test", "password")}" # pylint: disable=line-too-long + }, + ) + + assert response.status_code == HTTPStatus.OK + + invite_response = response.json() + assert "email" in invite_response + assert "inviteToken" in invite_response + + assert invite_response["email"] == "editor@integration.test" + + # Register the editor user using the invite token + response = requests.post( + signoz.self.host_config.get("/api/v1/register"), + json={ + "email": "editor@integration.test", + "password": "password", + "name": "editor", + "token": f"{invite_response["inviteToken"]}", + }, + timeout=2, + ) + assert response.status_code == HTTPStatus.OK + + # Verify that the invite token has been deleted + response = requests.get( + signoz.self.host_config.get( + f"/api/v1/invite/{invite_response["inviteToken"]}" + ), # pylint: disable=line-too-long + timeout=2, + ) + + assert response.status_code in (HTTPStatus.NOT_FOUND, HTTPStatus.BAD_REQUEST) + + # Verify that an admin endpoint cannot be called by the editor user + response = requests.get( + signoz.self.host_config.get("/api/v1/user"), + timeout=2, + headers={ + "Authorization": f"Bearer {get_jwt_token("editor@integration.test", "password")}" # pylint: disable=line-too-long + }, + ) + + assert response.status_code == HTTPStatus.FORBIDDEN + + # Verify that the editor has been created + response = requests.get( + signoz.self.host_config.get("/api/v1/user"), + timeout=2, + headers={ + "Authorization": f"Bearer {get_jwt_token("admin@integration.test", "password")}" # pylint: disable=line-too-long + }, + ) + + assert response.status_code == HTTPStatus.OK + + user_response = response.json() + found_user = next( + (user for user in user_response if user["email"] == "editor@integration.test"), + None, + ) + + assert found_user is not None + assert found_user["role"] == "EDITOR" + assert found_user["name"] == "editor" + assert found_user["email"] == "editor@integration.test" + + +def test_revoke_invite_and_register(signoz: types.SigNoz, get_jwt_token) -> None: + admin_token = get_jwt_token("admin@integration.test", "password") + # Generate an invite token for the viewer user + response = requests.post( + signoz.self.host_config.get("/api/v1/invite"), + json={"email": "viewer@integration.test", "role": "VIEWER"}, + timeout=2, + headers={ + "Authorization": f"Bearer {admin_token}" # pylint: disable=line-too-long + }, + ) + + assert response.status_code == HTTPStatus.OK + + invite_response = response.json() + assert "email" in invite_response + assert "inviteToken" in invite_response + + response = requests.delete( + signoz.self.host_config.get(f"/api/v1/invite/{invite_response['email']}"), + timeout=2, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == HTTPStatus.OK + + # Try registering the viewer user with the invite token + response = requests.post( + signoz.self.host_config.get("/api/v1/register"), + json={ + "email": "viewer@integration.test", + "password": "password", + "name": "viewer", + "token": f"{invite_response["inviteToken"]}", + }, + timeout=2, + ) + + assert response.status_code in (HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND) + + +def test_self_access(signoz: types.SigNoz, get_jwt_token) -> None: + admin_token = get_jwt_token("admin@integration.test", "password") + + response = requests.get( + signoz.self.host_config.get("/api/v1/user"), + timeout=2, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == HTTPStatus.OK + + user_response = response.json() + found_user = next( + (user for user in user_response if user["email"] == "editor@integration.test"), + None, + ) + + response = requests.get( + signoz.self.host_config.get(f"/api/v1/rbac/role/{found_user['id']}"), + timeout=2, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == HTTPStatus.OK + assert response.json()["group_name"] == "EDITOR" diff --git a/tests/integration/src/bootstrap/c_license.py b/tests/integration/src/bootstrap/c_license.py index be29758955..bcb7f9fc65 100644 --- a/tests/integration/src/bootstrap/c_license.py +++ b/tests/integration/src/bootstrap/c_license.py @@ -14,7 +14,7 @@ from fixtures.types import SigNoz def test_apply_license(signoz: SigNoz, make_http_mocks, get_jwt_token) -> None: make_http_mocks( - signoz.zeus.container, + signoz.zeus, [ Mapping( request=MappingRequest( @@ -51,7 +51,7 @@ def test_apply_license(signoz: SigNoz, make_http_mocks, get_jwt_token) -> None: ], ) - access_token = get_jwt_token("admin@admin.com", "password") + access_token = get_jwt_token("admin@integration.test", "password") response = requests.post( url=signoz.self.host_config.get("/api/v3/licenses"), diff --git a/tests/integration/src/bootstrap/d_apikey.py b/tests/integration/src/bootstrap/d_apikey.py new file mode 100644 index 0000000000..1a27d7b44c --- /dev/null +++ b/tests/integration/src/bootstrap/d_apikey.py @@ -0,0 +1,54 @@ +from fixtures import types +import requests +from http import HTTPStatus + + +def test_api_key(signoz: types.SigNoz, get_jwt_token) -> None: + admin_token = get_jwt_token("admin@integration.test", "password") + + response = requests.post( + signoz.self.host_config.get("/api/v1/pats"), + headers={"Authorization": f"Bearer {admin_token}"}, + json={ + "name": "admin", + "role": "ADMIN", + "expiresInDays": 1, + }, + ) + + assert response.status_code == HTTPStatus.OK + pat_response = response.json() + assert "data" in pat_response + assert "token" in pat_response["data"] + + response = requests.get( + signoz.self.host_config.get("/api/v1/user"), + timeout=2, + headers={"SIGNOZ-API-KEY": f"{pat_response["data"]["token"]}"}, + ) + + assert response.status_code == HTTPStatus.OK + + user_response = response.json() + found_user = next( + (user for user in user_response if user["email"] == "admin@integration.test"), + None, + ) + + response = requests.get( + signoz.self.host_config.get("/api/v1/pats"), + headers={"SIGNOZ-API-KEY": f"{pat_response["data"]["token"]}"}, + ) + + assert response.status_code == HTTPStatus.OK + assert "data" in response.json() + + found_pat = next( + (pat for pat in response.json()["data"] if pat["userId"] == found_user["id"]), + None, + ) + + assert found_pat is not None + assert found_pat["userId"] == found_user["id"] + assert found_pat["name"] == "admin" + assert found_pat["role"] == "ADMIN"