feat(auth): drop group table (#7672)

### Summary

drop group table
This commit is contained in:
Vibhu Pandey 2025-04-26 15:50:02 +05:30 committed by GitHub
parent b60588a749
commit 9e449e2858
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
84 changed files with 1762 additions and 1227 deletions

View File

@ -0,0 +1,27 @@
services:
postgres:
image: postgres:15
container_name: postgres
environment:
POSTGRES_DB: signoz
POSTGRES_USER: postgres
POSTGRES_PASSWORD: password
healthcheck:
test:
[
"CMD",
"pg_isready",
"-d",
"signoz",
"-U",
"postgres"
]
interval: 30s
timeout: 30s
retries: 3
restart: on-failure
ports:
- "127.0.0.1:5432:5432/tcp"
volumes:
- ${PWD}/fs/tmp/var/lib/postgresql/data/:/var/lib/postgresql/data/

View File

@ -44,10 +44,8 @@ jobs:
- name: run - name: run
run: | run: |
cd tests/integration && \ cd tests/integration && \
poetry run pytest -ra \ poetry run pytest \
--basetemp=./tmp/ \ --basetemp=./tmp/ \
-vv \
--capture=no \
src/${{matrix.src}} \ src/${{matrix.src}} \
--sqlstore-provider ${{matrix.sqlstore-provider}} \ --sqlstore-provider ${{matrix.sqlstore-provider}} \
--postgres-version ${{matrix.postgres-version}} \ --postgres-version ${{matrix.postgres-version}} \

View File

@ -56,6 +56,11 @@ devenv-clickhouse: ## Run clickhouse in devenv
@cd .devenv/docker/clickhouse; \ @cd .devenv/docker/clickhouse; \
docker compose -f compose.yaml up -d docker compose -f compose.yaml up -d
.PHONY: devenv-postgres
devenv-postgres: ## Run postgres in devenv
@cd .devenv/docker/postgres; \
docker compose -f compose.yaml up -d
############################################################## ##############################################################
# go commands # go commands
############################################################## ##############################################################

View File

@ -61,11 +61,17 @@ func (p *Pat) Wrap(next http.Handler) http.Handler {
return return
} }
role, err := authtypes.NewRole(user.Role)
if err != nil {
next.ServeHTTP(w, r)
return
}
jwt := authtypes.Claims{ jwt := authtypes.Claims{
UserID: user.ID, UserID: user.ID,
GroupID: user.GroupID, Role: role,
Email: user.Email, Email: user.Email,
OrgID: user.OrgID, OrgID: user.OrgID,
} }
ctx = authtypes.NewContextWithClaims(ctx, jwt) ctx = authtypes.NewContextWithClaims(ctx, jwt)

View File

@ -12,6 +12,7 @@ import (
"github.com/SigNoz/signoz/ee/query-service/usage" "github.com/SigNoz/signoz/ee/query-service/usage"
"github.com/SigNoz/signoz/pkg/alertmanager" "github.com/SigNoz/signoz/pkg/alertmanager"
"github.com/SigNoz/signoz/pkg/apis/fields" "github.com/SigNoz/signoz/pkg/apis/fields"
"github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/modules/preference" "github.com/SigNoz/signoz/pkg/modules/preference"
preferencecore "github.com/SigNoz/signoz/pkg/modules/preference/core" preferencecore "github.com/SigNoz/signoz/pkg/modules/preference/core"
@ -126,7 +127,7 @@ func (ah *APIHandler) CheckFeature(f string) bool {
} }
// RegisterRoutes registers routes for this handler on the given router // RegisterRoutes registers routes for this handler on the given router
func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *baseapp.AuthMiddleware) { func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *middleware.AuthZ) {
// note: add ee override methods first // note: add ee override methods first
// routes available only in ee version // routes available only in ee version
@ -199,7 +200,7 @@ func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *baseapp.AuthMiddlew
} }
func (ah *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *baseapp.AuthMiddleware) { func (ah *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *middleware.AuthZ) {
ah.APIHandler.RegisterCloudIntegrationsRoutes(router, am) ah.APIHandler.RegisterCloudIntegrationsRoutes(router, am)

View File

@ -12,11 +12,11 @@ import (
"github.com/SigNoz/signoz/ee/query-service/constants" "github.com/SigNoz/signoz/ee/query-service/constants"
eeTypes "github.com/SigNoz/signoz/ee/types" eeTypes "github.com/SigNoz/signoz/ee/types"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/auth"
baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
basemodel "github.com/SigNoz/signoz/pkg/query-service/model" basemodel "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"go.uber.org/zap" "go.uber.org/zap"
@ -30,6 +30,12 @@ type CloudIntegrationConnectionParamsResponse struct {
} }
func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseWriter, r *http.Request) {
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
cloudProvider := mux.Vars(r)["cloudProvider"] cloudProvider := mux.Vars(r)["cloudProvider"]
if cloudProvider != "aws" { if cloudProvider != "aws" {
RespondError(w, basemodel.BadRequest(fmt.Errorf( RespondError(w, basemodel.BadRequest(fmt.Errorf(
@ -38,15 +44,7 @@ func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseW
return return
} }
currentUser, err := auth.GetUserFromReqContext(r.Context()) apiKey, apiErr := ah.getOrCreateCloudIntegrationPAT(r.Context(), claims.OrgID, cloudProvider)
if err != nil {
RespondError(w, basemodel.UnauthorizedError(fmt.Errorf(
"couldn't deduce current user: %w", err,
)), nil)
return
}
apiKey, apiErr := ah.getOrCreateCloudIntegrationPAT(r.Context(), currentUser.OrgID, cloudProvider)
if apiErr != nil { if apiErr != nil {
RespondError(w, basemodel.WrapApiError( RespondError(w, basemodel.WrapApiError(
apiErr, "couldn't provision PAT for cloud integration:", apiErr, "couldn't provision PAT for cloud integration:",
@ -137,7 +135,7 @@ func (ah *APIHandler) getOrCreateCloudIntegrationPAT(ctx context.Context, orgId
newPAT := eeTypes.NewGettablePAT( newPAT := eeTypes.NewGettablePAT(
integrationPATName, integrationPATName,
baseconstants.ViewerGroup, authtypes.RoleViewer.String(),
integrationUser.ID, integrationUser.ID,
0, 0,
) )
@ -181,11 +179,7 @@ func (ah *APIHandler) getOrCreateCloudIntegrationUser(
OrgID: orgId, OrgID: orgId,
} }
viewerGroup, apiErr := dao.DB().GetGroupByName(ctx, baseconstants.ViewerGroup) newUser.Role = authtypes.RoleViewer.String()
if apiErr != nil {
return nil, basemodel.WrapApiError(apiErr, "couldn't get viewer group for creating integration user")
}
newUser.GroupID = viewerGroup.ID
passwordHash, err := auth.PasswordHash(uuid.NewString()) passwordHash, err := auth.PasswordHash(uuid.NewString())
if err != nil { if err != nil {

View File

@ -7,7 +7,6 @@ import (
"github.com/SigNoz/signoz/pkg/errors" "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/query-service/app/dashboards" "github.com/SigNoz/signoz/pkg/query-service/app/dashboards"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
@ -36,18 +35,19 @@ func (ah *APIHandler) lockUnlockDashboard(w http.ResponseWriter, r *http.Request
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok { if err != nil {
render.Error(w, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(w, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
return return
} }
dashboard, err := dashboards.GetDashboard(r.Context(), claims.OrgID, uuid) dashboard, err := dashboards.GetDashboard(r.Context(), claims.OrgID, uuid)
if err != nil { if err != nil {
render.Error(w, errors.Wrapf(err, errors.TypeInternal, errors.CodeInternal, "failed to get dashboard")) render.Error(w, errors.Wrapf(err, errors.TypeInternal, errors.CodeInternal, "failed to get dashboard"))
return return
} }
if !auth.IsAdminV2(claims) && (dashboard.CreatedBy != claims.Email) { if err := claims.IsAdmin(); err != nil && (dashboard.CreatedBy != claims.Email) {
render.Error(w, errors.Newf(errors.TypeForbidden, errors.CodeForbidden, "You are not authorized to lock/unlock this dashboard")) render.Error(w, errors.Newf(errors.TypeForbidden, errors.CodeForbidden, "You are not authorized to lock/unlock this dashboard"))
return return
} }

View File

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -13,35 +12,31 @@ import (
"github.com/SigNoz/signoz/pkg/errors" "github.com/SigNoz/signoz/pkg/errors"
errorsV2 "github.com/SigNoz/signoz/pkg/errors" errorsV2 "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/query-service/auth"
baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants"
basemodel "github.com/SigNoz/signoz/pkg/query-service/model" basemodel "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/valuer" "github.com/SigNoz/signoz/pkg/valuer"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"go.uber.org/zap" "go.uber.org/zap"
) )
func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
ctx := context.Background() claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
req := model.CreatePATRequestBody{} req := model.CreatePATRequestBody{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
RespondError(w, model.BadRequest(err), nil) RespondError(w, model.BadRequest(err), nil)
return return
} }
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
pat := eeTypes.NewGettablePAT( pat := eeTypes.NewGettablePAT(
req.Name, req.Name,
req.Role, req.Role,
user.ID, claims.UserID,
req.ExpiresInDays, req.ExpiresInDays,
) )
err = validatePATRequest(pat) err = validatePATRequest(pat)
@ -52,7 +47,7 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
zap.L().Info("Got Create PAT request", zap.Any("pat", pat)) zap.L().Info("Got Create PAT request", zap.Any("pat", pat))
var apierr basemodel.BaseApiError var apierr basemodel.BaseApiError
if pat, apierr = ah.AppDao().CreatePAT(ctx, user.OrgID, pat); apierr != nil { if pat, apierr = ah.AppDao().CreatePAT(r.Context(), claims.OrgID, pat); apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }
@ -61,20 +56,28 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
} }
func validatePATRequest(req eeTypes.GettablePAT) error { func validatePATRequest(req eeTypes.GettablePAT) error {
if req.Role == "" || (req.Role != baseconstants.ViewerGroup && req.Role != baseconstants.EditorGroup && req.Role != baseconstants.AdminGroup) { _, err := authtypes.NewRole(req.Role)
return fmt.Errorf("valid role is required") if err != nil {
return err
} }
if req.ExpiresAt < 0 { if req.ExpiresAt < 0 {
return fmt.Errorf("valid expiresAt is required") return fmt.Errorf("valid expiresAt is required")
} }
if req.Name == "" { if req.Name == "" {
return fmt.Errorf("valid name is required") return fmt.Errorf("valid name is required")
} }
return nil return nil
} }
func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
ctx := context.Background() claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
req := eeTypes.GettablePAT{} req := eeTypes.GettablePAT{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -89,24 +92,15 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
return return
} }
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
//get the pat //get the pat
existingPAT, paterr := ah.AppDao().GetPATByID(ctx, user.OrgID, id) existingPAT, paterr := ah.AppDao().GetPATByID(r.Context(), claims.OrgID, id)
if paterr != nil { if paterr != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, paterr.Error())) render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, paterr.Error()))
return return
} }
// get the user // get the user
createdByUser, usererr := ah.AppDao().GetUser(ctx, existingPAT.UserID) createdByUser, usererr := ah.AppDao().GetUser(r.Context(), existingPAT.UserID)
if usererr != nil { if usererr != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, usererr.Error())) render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, usererr.Error()))
return return
@ -123,11 +117,11 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
return return
} }
req.UpdatedByUserID = user.ID req.UpdatedByUserID = claims.UserID
req.UpdatedAt = time.Now() req.UpdatedAt = time.Now()
zap.L().Info("Got Update PAT request", zap.Any("pat", req)) zap.L().Info("Got Update PAT request", zap.Any("pat", req))
var apierr basemodel.BaseApiError var apierr basemodel.BaseApiError
if apierr = ah.AppDao().UpdatePAT(ctx, user.OrgID, req, id); apierr != nil { if apierr = ah.AppDao().UpdatePAT(r.Context(), claims.OrgID, req, id); apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }
@ -136,50 +130,44 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
} }
func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) {
ctx := context.Background() claims, err := authtypes.ClaimsFromContext(r.Context())
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil { if err != nil {
RespondError(w, &model.ApiError{ render.Error(w, err)
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return return
} }
zap.L().Info("Get PATs for user", zap.String("user_id", user.ID))
pats, apierr := ah.AppDao().ListPATs(ctx, user.OrgID) pats, apierr := ah.AppDao().ListPATs(r.Context(), claims.OrgID)
if apierr != nil { if apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }
ah.Respond(w, pats) ah.Respond(w, pats)
} }
func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) {
ctx := context.Background() claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
idStr := mux.Vars(r)["id"] idStr := mux.Vars(r)["id"]
id, err := valuer.NewUUID(idStr) id, err := valuer.NewUUID(idStr)
if err != nil { if err != nil {
render.Error(w, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "id is not a valid uuid-v7")) render.Error(w, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "id is not a valid uuid-v7"))
return return
} }
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
//get the pat //get the pat
existingPAT, paterr := ah.AppDao().GetPATByID(ctx, user.OrgID, id) existingPAT, paterr := ah.AppDao().GetPATByID(r.Context(), claims.OrgID, id)
if paterr != nil { if paterr != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, paterr.Error())) render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, paterr.Error()))
return return
} }
// get the user // get the user
createdByUser, usererr := ah.AppDao().GetUser(ctx, existingPAT.UserID) createdByUser, usererr := ah.AppDao().GetUser(r.Context(), existingPAT.UserID)
if usererr != nil { if usererr != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, usererr.Error())) render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, usererr.Error()))
return return
@ -191,7 +179,7 @@ func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) {
} }
zap.L().Info("Revoke PAT with id", zap.String("id", id.StringValue())) zap.L().Info("Revoke PAT with id", zap.String("id", id.StringValue()))
if apierr := ah.AppDao().RevokePAT(ctx, user.OrgID, id, user.ID); apierr != nil { if apierr := ah.AppDao().RevokePAT(r.Context(), claims.OrgID, id, claims.UserID); apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }

View File

@ -2,7 +2,6 @@ package app
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -22,11 +21,9 @@ import (
"github.com/SigNoz/signoz/pkg/alertmanager" "github.com/SigNoz/signoz/pkg/alertmanager"
"github.com/SigNoz/signoz/pkg/http/middleware" "github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/prometheus" "github.com/SigNoz/signoz/pkg/prometheus"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/signoz" "github.com/SigNoz/signoz/pkg/signoz"
"github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/sqlstore"
"github.com/SigNoz/signoz/pkg/telemetrystore" "github.com/SigNoz/signoz/pkg/telemetrystore"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/web" "github.com/SigNoz/signoz/pkg/web"
"github.com/rs/cors" "github.com/rs/cors"
@ -334,24 +331,8 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server,
} }
func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*http.Server, error) { func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*http.Server, error) {
r := baseapp.NewRouter() r := baseapp.NewRouter()
am := middleware.NewAuthZ(s.serverOptions.SigNoz.Instrumentation.Logger())
// add auth middleware
getUserFromRequest := func(ctx context.Context) (*types.GettableUser, error) {
user, err := auth.GetUserFromReqContext(ctx)
if err != nil {
return nil, err
}
if user.User.OrgID == "" {
return nil, basemodel.UnauthorizedError(errors.New("orgId is missing in the claims"))
}
return user, nil
}
am := baseapp.NewAuthMiddleware(getUserFromRequest)
r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
r.Use(eemiddleware.NewPat(s.serverOptions.SigNoz.SQLStore, []string{"SIGNOZ-API-KEY"}).Wrap) r.Use(eemiddleware.NewPat(s.serverOptions.SigNoz.SQLStore, []string{"SIGNOZ-API-KEY"}).Wrap)

View File

@ -9,7 +9,6 @@ import (
"github.com/SigNoz/signoz/ee/query-service/constants" "github.com/SigNoz/signoz/ee/query-service/constants"
"github.com/SigNoz/signoz/ee/query-service/model" "github.com/SigNoz/signoz/ee/query-service/model"
baseauth "github.com/SigNoz/signoz/pkg/query-service/auth" baseauth "github.com/SigNoz/signoz/pkg/query-service/auth"
baseconst "github.com/SigNoz/signoz/pkg/query-service/constants"
basemodel "github.com/SigNoz/signoz/pkg/query-service/model" basemodel "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/query-service/utils" "github.com/SigNoz/signoz/pkg/query-service/utils"
"github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types"
@ -36,12 +35,6 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) (
return nil, model.InternalErrorStr("failed to generate password hash") return nil, model.InternalErrorStr("failed to generate password hash")
} }
group, apiErr := m.GetGroupByName(ctx, baseconst.ViewerGroup)
if apiErr != nil {
zap.L().Error("GetGroupByName failed", zap.Error(apiErr))
return nil, apiErr
}
user := &types.User{ user := &types.User{
ID: uuid.New().String(), ID: uuid.New().String(),
Name: "", Name: "",
@ -51,11 +44,11 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) (
CreatedAt: time.Now(), CreatedAt: time.Now(),
}, },
ProfilePictureURL: "", // Currently unused ProfilePictureURL: "", // Currently unused
GroupID: group.ID, Role: authtypes.RoleViewer.String(),
OrgID: domain.OrgID, OrgID: domain.OrgID,
} }
user, apiErr = m.CreateUser(ctx, user, false) user, apiErr := m.CreateUser(ctx, user, false)
if apiErr != nil { if apiErr != nil {
zap.L().Error("CreateUser failed", zap.Error(apiErr)) zap.L().Error("CreateUser failed", zap.Error(apiErr))
return nil, apiErr return nil, apiErr
@ -115,7 +108,7 @@ func (m *modelDao) CanUsePassword(ctx context.Context, email string) (bool, base
return false, baseapierr return false, baseapierr
} }
if userPayload.Role != baseconst.AdminGroup { if userPayload.Role != authtypes.RoleAdmin.String() {
return false, model.BadRequest(fmt.Errorf("auth method not supported")) return false, model.BadRequest(fmt.Errorf("auth method not supported"))
} }

View File

@ -239,8 +239,8 @@ func (lm *Manager) ValidateV3(ctx context.Context) (reterr error) {
func (lm *Manager) ActivateV3(ctx context.Context, licenseKey string) (licenseResponse *model.LicenseV3, errResponse *model.ApiError) { func (lm *Manager) ActivateV3(ctx context.Context, licenseKey string) (licenseResponse *model.LicenseV3, errResponse *model.ApiError) {
defer func() { defer func() {
if errResponse != nil { if errResponse != nil {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if ok { if err != nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_LICENSE_ACT_FAILED, telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_LICENSE_ACT_FAILED,
map[string]interface{}{"err": errResponse.Err.Error()}, claims.Email, true, false) map[string]interface{}{"err": errResponse.Err.Error()}, claims.Email, true, false)
} }

View File

@ -11,7 +11,6 @@ import (
"github.com/SigNoz/signoz/pkg/config" "github.com/SigNoz/signoz/pkg/config"
"github.com/SigNoz/signoz/pkg/config/envprovider" "github.com/SigNoz/signoz/pkg/config/envprovider"
"github.com/SigNoz/signoz/pkg/config/fileprovider" "github.com/SigNoz/signoz/pkg/config/fileprovider"
"github.com/SigNoz/signoz/pkg/query-service/auth"
baseconst "github.com/SigNoz/signoz/pkg/query-service/constants" baseconst "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/signoz" "github.com/SigNoz/signoz/pkg/signoz"
"github.com/SigNoz/signoz/pkg/sqlstore/sqlstorehook" "github.com/SigNoz/signoz/pkg/sqlstore/sqlstorehook"
@ -147,10 +146,6 @@ func main() {
zap.L().Fatal("Could not start server", zap.Error(err)) zap.L().Fatal("Could not start server", zap.Error(err))
} }
if err := auth.InitAuthCache(context.Background()); err != nil {
zap.L().Fatal("Failed to initialize auth cache", zap.Error(err))
}
signoz.Start(context.Background()) signoz.Start(context.Background())
if err := signoz.Wait(context.Background()); err != nil { if err := signoz.Wait(context.Background()); err != nil {

View File

@ -28,10 +28,9 @@ var (
CloudIntegrationReference = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE` CloudIntegrationReference = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE`
) )
type dialect struct { type dialect struct{}
}
func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
columnType, err := dialect.GetColumnType(ctx, bun, table, column) columnType, err := dialect.GetColumnType(ctx, bun, table, column)
if err != nil { if err != nil {
return err return err
@ -78,7 +77,7 @@ func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB,
return nil return nil
} }
func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
columnExists, err := dialect.ColumnExists(ctx, bun, table, column) columnExists, err := dialect.ColumnExists(ctx, bun, table, column)
if err != nil { if err != nil {
return err return err
@ -420,3 +419,26 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel
return nil return nil
} }
func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bunIDB bun.IDB, model interface{}, column string) error {
existingTable := bunIDB.Dialect().Tables().Get(reflect.TypeOf(model))
columnExists, err := dialect.ColumnExists(ctx, bunIDB, existingTable.Name, column)
if err != nil {
return err
}
if !columnExists {
return nil
}
_, err = bunIDB.
NewDropColumn().
Model(model).
Column(column).
Exec(ctx)
if err != nil {
return err
}
return nil
}

1
frontend/.gitignore vendored
View File

@ -1,3 +1,4 @@
# Sentry Config File # Sentry Config File
.env.sentry-build-plugin .env.sentry-build-plugin
.qodo

View File

@ -23,7 +23,6 @@ function getUserDefaults(): IUser {
organization: '', organization: '',
orgId: '', orgId: '',
role: 'VIEWER', role: 'VIEWER',
groupId: '',
}; };
} }

View File

@ -151,7 +151,6 @@ export function getAppContextMock(
organization: 'Nightswatch', organization: 'Nightswatch',
orgId: 'does-not-matter-id', orgId: 'does-not-matter-id',
role: role as ROLES, role: role as ROLES,
groupId: 'does-not-matter-groupId',
}, },
org: [ org: [
{ {

View File

@ -15,5 +15,4 @@ export interface PayloadProps {
profilePictureURL: string; profilePictureURL: string;
organization: string; organization: string;
role: ROLES; role: ROLES;
groupId: string;
} }

View File

@ -28,9 +28,9 @@ func (api *API) GetAlerts(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel() defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
@ -53,9 +53,9 @@ func (api *API) TestReceiver(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel() defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
@ -85,9 +85,9 @@ func (api *API) ListChannels(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel() defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
@ -122,9 +122,9 @@ func (api *API) GetChannelByID(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel() defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
@ -159,9 +159,9 @@ func (api *API) UpdateChannelByID(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel() defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
@ -209,9 +209,9 @@ func (api *API) DeleteChannelByID(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel() defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
@ -246,9 +246,9 @@ func (api *API) CreateChannel(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second) ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel() defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }

View File

@ -46,8 +46,8 @@ func (a *Analytics) Wrap(next http.Handler) http.Handler {
} }
if _, ok := telemetry.EnabledPaths()[path]; ok { if _, ok := telemetry.EnabledPaths()[path]; ok {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if ok { if err == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, claims.Email, true, false) telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, claims.Email, true, false)
} }
} }
@ -134,8 +134,8 @@ func (a *Analytics) extractQueryRangeData(path string, r *http.Request) (map[str
data["queryType"] = queryInfoResult.QueryType data["queryType"] = queryInfoResult.QueryType
data["panelType"] = queryInfoResult.PanelType data["panelType"] = queryInfoResult.PanelType
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if ok { if err == nil {
// switch case to set data["screen"] based on the referrer // switch case to set data["screen"] based on the referrer
switch { switch {
case dashboardMatched: case dashboardMatched:

View File

@ -28,9 +28,7 @@ func (a *Auth) Wrap(next http.Handler) http.Handler {
values = append(values, r.Header.Get(header)) values = append(values, r.Header.Get(header))
} }
ctx, err := a.jwt.ContextFromRequest( ctx, err := a.jwt.ContextFromRequest(r.Context(), values...)
r.Context(),
values...)
if err != nil { if err != nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return

View File

@ -0,0 +1,105 @@
package middleware
import (
"log/slog"
"net/http"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/gorilla/mux"
)
const (
authzDeniedMessage string = "::AUTHZ-DENIED::"
)
type AuthZ struct {
logger *slog.Logger
}
func NewAuthZ(logger *slog.Logger) *AuthZ {
if logger == nil {
panic("cannot build authz middleware, logger is empty")
}
return &AuthZ{logger: logger}
}
func (middleware *AuthZ) ViewAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
claims, err := authtypes.ClaimsFromContext(req.Context())
if err != nil {
render.Error(rw, err)
return
}
if err := claims.IsViewer(); err != nil {
middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims)
render.Error(rw, err)
return
}
next(rw, req)
})
}
func (middleware *AuthZ) EditAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
claims, err := authtypes.ClaimsFromContext(req.Context())
if err != nil {
render.Error(rw, err)
return
}
if err := claims.IsEditor(); err != nil {
middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims)
render.Error(rw, err)
return
}
next(rw, req)
})
}
func (middleware *AuthZ) AdminAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
claims, err := authtypes.ClaimsFromContext(req.Context())
if err != nil {
render.Error(rw, err)
return
}
if err := claims.IsAdmin(); err != nil {
middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims)
render.Error(rw, err)
return
}
next(rw, req)
})
}
func (middleware *AuthZ) SelfAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
claims, err := authtypes.ClaimsFromContext(req.Context())
if err != nil {
render.Error(rw, err)
return
}
id := mux.Vars(req)["id"]
if err := claims.IsSelfAccess(id); err != nil {
middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims)
render.Error(rw, err)
return
}
next(rw, req)
})
}
func (middleware *AuthZ) OpenAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next(rw, req)
})
}

View File

@ -136,8 +136,8 @@ func (middleware *Logging) getLogCommentKVs(r *http.Request) map[string]string {
} }
var email string var email string
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if ok { if err == nil {
email = claims.Email email = claims.Email
} }

View File

@ -58,6 +58,8 @@ func Error(rw http.ResponseWriter, cause error) {
httpCode = http.StatusUnauthorized httpCode = http.StatusUnauthorized
case errors.TypeUnsupported: case errors.TypeUnsupported:
httpCode = http.StatusNotImplemented httpCode = http.StatusNotImplemented
case errors.TypeForbidden:
httpCode = http.StatusForbidden
} }
rea := make([]responseerroradditional, len(a)) rea := make([]responseerroradditional, len(a))

View File

@ -21,11 +21,12 @@ func NewAPI(module organization.Module) organization.API {
} }
func (api *organizationAPI) Get(rw http.ResponseWriter, r *http.Request) { func (api *organizationAPI) Get(rw http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
orgID, err := valuer.NewUUID(claims.OrgID) orgID, err := valuer.NewUUID(claims.OrgID)
if err != nil { if err != nil {
render.Error(rw, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid org id")) render.Error(rw, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid org id"))
@ -52,11 +53,12 @@ func (api *organizationAPI) GetAll(rw http.ResponseWriter, r *http.Request) {
} }
func (api *organizationAPI) Update(rw http.ResponseWriter, r *http.Request) { func (api *organizationAPI) Update(rw http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok { if err != nil {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
orgID, err := valuer.NewUUID(claims.OrgID) orgID, err := valuer.NewUUID(claims.OrgID)
if err != nil { if err != nil {
render.Error(rw, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid org id")) render.Error(rw, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid org id"))

View File

@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
errorsV2 "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/types/preferencetypes" "github.com/SigNoz/signoz/pkg/types/preferencetypes"
@ -30,9 +29,9 @@ func NewAPI(usecase Usecase) API {
func (p *preferenceAPI) GetOrgPreference(rw http.ResponseWriter, r *http.Request) { func (p *preferenceAPI) GetOrgPreference(rw http.ResponseWriter, r *http.Request) {
preferenceId := mux.Vars(r)["preferenceId"] preferenceId := mux.Vars(r)["preferenceId"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok { if err != nil {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
preference, err := p.usecase.GetOrgPreference( preference, err := p.usecase.GetOrgPreference(
@ -49,18 +48,18 @@ func (p *preferenceAPI) GetOrgPreference(rw http.ResponseWriter, r *http.Request
func (p *preferenceAPI) UpdateOrgPreference(rw http.ResponseWriter, r *http.Request) { func (p *preferenceAPI) UpdateOrgPreference(rw http.ResponseWriter, r *http.Request) {
preferenceId := mux.Vars(r)["preferenceId"] preferenceId := mux.Vars(r)["preferenceId"]
req := preferencetypes.UpdatablePreference{} req := preferencetypes.UpdatablePreference{}
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
return
}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
render.Error(rw, err) render.Error(rw, err)
return return
} }
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
render.Error(rw, err)
return
}
err = p.usecase.UpdateOrgPreference(r.Context(), preferenceId, req.PreferenceValue, claims.OrgID) err = p.usecase.UpdateOrgPreference(r.Context(), preferenceId, req.PreferenceValue, claims.OrgID)
if err != nil { if err != nil {
render.Error(rw, err) render.Error(rw, err)
@ -71,9 +70,9 @@ func (p *preferenceAPI) UpdateOrgPreference(rw http.ResponseWriter, r *http.Requ
} }
func (p *preferenceAPI) GetAllOrgPreferences(rw http.ResponseWriter, r *http.Request) { func (p *preferenceAPI) GetAllOrgPreferences(rw http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok { if err != nil {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
preferences, err := p.usecase.GetAllOrgPreferences( preferences, err := p.usecase.GetAllOrgPreferences(
@ -89,9 +88,9 @@ func (p *preferenceAPI) GetAllOrgPreferences(rw http.ResponseWriter, r *http.Req
func (p *preferenceAPI) GetUserPreference(rw http.ResponseWriter, r *http.Request) { func (p *preferenceAPI) GetUserPreference(rw http.ResponseWriter, r *http.Request) {
preferenceId := mux.Vars(r)["preferenceId"] preferenceId := mux.Vars(r)["preferenceId"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok { if err != nil {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
@ -108,14 +107,14 @@ func (p *preferenceAPI) GetUserPreference(rw http.ResponseWriter, r *http.Reques
func (p *preferenceAPI) UpdateUserPreference(rw http.ResponseWriter, r *http.Request) { func (p *preferenceAPI) UpdateUserPreference(rw http.ResponseWriter, r *http.Request) {
preferenceId := mux.Vars(r)["preferenceId"] preferenceId := mux.Vars(r)["preferenceId"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok { if err != nil {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
req := preferencetypes.UpdatablePreference{} req := preferencetypes.UpdatablePreference{}
err := json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
render.Error(rw, err) render.Error(rw, err)
@ -131,9 +130,9 @@ func (p *preferenceAPI) UpdateUserPreference(rw http.ResponseWriter, r *http.Req
} }
func (p *preferenceAPI) GetAllUserPreferences(rw http.ResponseWriter, r *http.Request) { func (p *preferenceAPI) GetAllUserPreferences(rw http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, err := authtypes.ClaimsFromContext(r.Context())
if !ok { if err != nil {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(rw, err)
return return
} }
preferences, err := p.usecase.GetAllUserPreferences( preferences, err := p.usecase.GetAllUserPreferences(

View File

@ -1,19 +1,19 @@
package app package app
import ( import (
"errors"
"net/http" "net/http"
"strings" "strings"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/authtypes"
) )
func (aH *APIHandler) setApdexSettings(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) setApdexSettings(w http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
RespondError(w, &model.ApiError{Err: errors.New("unauthorized"), Typ: model.ErrorUnauthorized}, nil) render.Error(w, errv2)
return return
} }
req, err := parseSetApdexScoreRequest(r) req, err := parseSetApdexScoreRequest(r)
@ -31,9 +31,9 @@ func (aH *APIHandler) setApdexSettings(w http.ResponseWriter, r *http.Request) {
func (aH *APIHandler) getApdexSettings(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) getApdexSettings(w http.ResponseWriter, r *http.Request) {
services := r.URL.Query().Get("services") services := r.URL.Query().Get("services")
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
RespondError(w, &model.ApiError{Err: errors.New("unauthorized"), Typ: model.ErrorUnauthorized}, nil) render.Error(w, errv2)
return return
} }
apdexSet, err := dao.DB().GetApdexSettings(r.Context(), claims.OrgID, strings.Split(strings.TrimSpace(services), ",")) apdexSet, err := dao.DB().GetApdexSettings(r.Context(), claims.OrgID, strings.Split(strings.TrimSpace(services), ","))

View File

@ -1,123 +0,0 @@
package app
import (
"context"
"errors"
"net/http"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types"
"github.com/gorilla/mux"
)
type AuthMiddleware struct {
GetUserFromRequest func(r context.Context) (*types.GettableUser, error)
}
func NewAuthMiddleware(f func(ctx context.Context) (*types.GettableUser, error)) *AuthMiddleware {
return &AuthMiddleware{
GetUserFromRequest: f,
}
}
func (am *AuthMiddleware) OpenAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
f(w, r)
}
}
func (am *AuthMiddleware) ViewAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
if !(auth.IsViewer(user) || auth.IsEditor(user) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to viewers/editors/admins"),
}, nil)
return
}
ctx := context.WithValue(r.Context(), constants.ContextUserKey, user)
r = r.WithContext(ctx)
f(w, r)
}
}
func (am *AuthMiddleware) EditAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
if !(auth.IsEditor(user) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to editors/admins"),
}, nil)
return
}
ctx := context.WithValue(r.Context(), constants.ContextUserKey, user)
r = r.WithContext(ctx)
f(w, r)
}
}
func (am *AuthMiddleware) SelfAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
id := mux.Vars(r)["id"]
if !(auth.IsSelfAccessRequest(user, id) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible for self access or to the admins"),
}, nil)
return
}
ctx := context.WithValue(r.Context(), constants.ContextUserKey, user)
r = r.WithContext(ctx)
f(w, r)
}
}
func (am *AuthMiddleware) AdminAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
if !auth.IsAdmin(user) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to admins only"),
}, nil)
return
}
ctx := context.WithValue(r.Context(), constants.ContextUserKey, user)
r = r.WithContext(ctx)
f(w, r)
}
}

View File

@ -1037,7 +1037,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con
var serviceNameIntervalMap = map[string][]tracedetail.Interval{} var serviceNameIntervalMap = map[string][]tracedetail.Interval{}
var hasMissingSpans bool var hasMissingSpans bool
claims, claimsPresent := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
cachedTraceData, err := r.GetWaterfallSpansForTraceWithMetadataCache(ctx, traceID) cachedTraceData, err := r.GetWaterfallSpansForTraceWithMetadataCache(ctx, traceID)
if err == nil { if err == nil {
startTime = cachedTraceData.StartTime startTime = cachedTraceData.StartTime
@ -1050,7 +1050,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con
totalErrorSpans = cachedTraceData.TotalErrorSpans totalErrorSpans = cachedTraceData.TotalErrorSpans
hasMissingSpans = cachedTraceData.HasMissingSpans hasMissingSpans = cachedTraceData.HasMissingSpans
if claimsPresent { if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false) telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false)
} }
} }
@ -1067,7 +1067,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con
} }
totalSpans = uint64(len(searchScanResponses)) totalSpans = uint64(len(searchScanResponses))
if claimsPresent { if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false) telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false)
} }
@ -3280,8 +3280,8 @@ func (r *ClickHouseReader) GetLogs(ctx context.Context, params *model.LogsFilter
"lenFilters": lenFilters, "lenFilters": lenFilters,
} }
if lenFilters != 0 { if lenFilters != 0 {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false)
} }
} }
@ -3322,8 +3322,8 @@ func (r *ClickHouseReader) TailLogs(ctx context.Context, client *model.LogsTailC
"lenFilters": lenFilters, "lenFilters": lenFilters,
} }
if lenFilters != 0 { if lenFilters != 0 {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false)
} }
} }
@ -3414,8 +3414,8 @@ func (r *ClickHouseReader) AggregateLogs(ctx context.Context, params *model.Logs
"lenFilters": lenFilters, "lenFilters": lenFilters,
} }
if lenFilters != 0 { if lenFilters != 0 {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false) telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false)
} }
} }
@ -6835,8 +6835,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea
if traceSummary.NumSpans > uint64(params.MaxSpansInTrace) { if traceSummary.NumSpans > uint64(params.MaxSpansInTrace) {
zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace), zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace),
zap.Uint64("Count", traceSummary.NumSpans)) zap.Uint64("Count", traceSummary.NumSpans))
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
data := map[string]interface{}{ data := map[string]interface{}{
"traceSize": traceSummary.NumSpans, "traceSize": traceSummary.NumSpans,
"maxSpansInTraceLimit": params.MaxSpansInTrace, "maxSpansInTraceLimit": params.MaxSpansInTrace,
@ -6847,8 +6847,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea
return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details") return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details")
} }
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
data := map[string]interface{}{ data := map[string]interface{}{
"traceSize": traceSummary.NumSpans, "traceSize": traceSummary.NumSpans,
"algo": "smart", "algo": "smart",
@ -6937,8 +6937,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea
} }
end = time.Now() end = time.Now()
zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start))) zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start)))
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
data := map[string]interface{}{ data := map[string]interface{}{
"traceSize": len(searchScanResponses), "traceSize": len(searchScanResponses),
"spansRenderLimit": params.SpansRenderLimit, "spansRenderLimit": params.SpansRenderLimit,
@ -6976,8 +6976,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc
if countSpans > uint64(params.MaxSpansInTrace) { if countSpans > uint64(params.MaxSpansInTrace) {
zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace), zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace),
zap.Uint64("Count", countSpans)) zap.Uint64("Count", countSpans))
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
data := map[string]interface{}{ data := map[string]interface{}{
"traceSize": countSpans, "traceSize": countSpans,
"maxSpansInTraceLimit": params.MaxSpansInTrace, "maxSpansInTraceLimit": params.MaxSpansInTrace,
@ -6988,8 +6988,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc
return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details") return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details")
} }
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
data := map[string]interface{}{ data := map[string]interface{}{
"traceSize": countSpans, "traceSize": countSpans,
"algo": "smart", "algo": "smart",
@ -7049,8 +7049,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc
} }
end = time.Now() end = time.Now()
zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start))) zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start)))
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if ok { if errv2 == nil {
data := map[string]interface{}{ data := map[string]interface{}{
"traceSize": len(searchScanResponses), "traceSize": len(searchScanResponses),
"spansRenderLimit": params.SpansRenderLimit, "spansRenderLimit": params.SpansRenderLimit,

View File

@ -6,12 +6,11 @@ import (
"github.com/SigNoz/signoz/pkg/modules/organization" "github.com/SigNoz/signoz/pkg/modules/organization"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/query-service/utils" "github.com/SigNoz/signoz/pkg/query-service/utils"
"github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -300,13 +299,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup)
if apiErr != nil {
return nil, model.InternalError(apiErr)
}
auth.InitAuthCache(ctx)
userId := uuid.NewString() userId := uuid.NewString()
return dao.DB().CreateUser( return dao.DB().CreateUser(
ctx, ctx,
@ -316,7 +308,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
Email: userId[:8] + "test@test.com", Email: userId[:8] + "test@test.com",
Password: "test", Password: "test",
OrgID: organization.ID.StringValue(), OrgID: organization.ID.StringValue(),
GroupID: group.ID, Role: authtypes.RoleAdmin.String(),
}, },
true, true,
) )

View File

@ -108,8 +108,8 @@ func CreateView(ctx context.Context, orgID string, view v3.SavedView) (valuer.UU
createdAt := time.Now() createdAt := time.Now()
updatedAt := time.Now() updatedAt := time.Now()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
return valuer.UUID{}, fmt.Errorf("error in getting email from context") return valuer.UUID{}, fmt.Errorf("error in getting email from context")
} }
@ -177,8 +177,8 @@ func UpdateView(ctx context.Context, orgID string, uuid valuer.UUID, view v3.Sav
return fmt.Errorf("error in marshalling explorer query data: %s", err.Error()) return fmt.Errorf("error in marshalling explorer query data: %s", err.Error())
} }
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
return fmt.Errorf("error in getting email from context") return fmt.Errorf("error in getting email from context")
} }

View File

@ -21,6 +21,7 @@ import (
"github.com/SigNoz/signoz/pkg/alertmanager" "github.com/SigNoz/signoz/pkg/alertmanager"
"github.com/SigNoz/signoz/pkg/apis/fields" "github.com/SigNoz/signoz/pkg/apis/fields"
errorsV2 "github.com/SigNoz/signoz/pkg/errors" errorsV2 "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/modules/organization" "github.com/SigNoz/signoz/pkg/modules/organization"
"github.com/SigNoz/signoz/pkg/modules/preference" "github.com/SigNoz/signoz/pkg/modules/preference"
@ -54,7 +55,6 @@ import (
tracesV4 "github.com/SigNoz/signoz/pkg/query-service/app/traces/v4" tracesV4 "github.com/SigNoz/signoz/pkg/query-service/app/traces/v4"
"github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/cache" "github.com/SigNoz/signoz/pkg/query-service/cache"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/contextlinks" "github.com/SigNoz/signoz/pkg/query-service/contextlinks"
v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3" v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3"
"github.com/SigNoz/signoz/pkg/query-service/postprocess" "github.com/SigNoz/signoz/pkg/query-service/postprocess"
@ -405,7 +405,7 @@ func writeHttpResponse(w http.ResponseWriter, data interface{}) {
} }
} }
func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v3").Subrouter() subRouter := router.PathPrefix("/api/v3").Subrouter()
subRouter.HandleFunc("/autocomplete/aggregate_attributes", am.ViewAccess( subRouter.HandleFunc("/autocomplete/aggregate_attributes", am.ViewAccess(
withCacheControl(AutoCompleteCacheControlAge, aH.autocompleteAggregateAttributes))).Methods(http.MethodGet) withCacheControl(AutoCompleteCacheControlAge, aH.autocompleteAggregateAttributes))).Methods(http.MethodGet)
@ -431,14 +431,14 @@ func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *AuthMid
subRouter.HandleFunc("/logs/livetail", am.ViewAccess(aH.liveTailLogs)).Methods(http.MethodGet) subRouter.HandleFunc("/logs/livetail", am.ViewAccess(aH.liveTailLogs)).Methods(http.MethodGet)
} }
func (aH *APIHandler) RegisterFieldsRoutes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterFieldsRoutes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v1").Subrouter() subRouter := router.PathPrefix("/api/v1").Subrouter()
subRouter.HandleFunc("/fields/keys", am.ViewAccess(aH.FieldsAPI.GetFieldsKeys)).Methods(http.MethodGet) subRouter.HandleFunc("/fields/keys", am.ViewAccess(aH.FieldsAPI.GetFieldsKeys)).Methods(http.MethodGet)
subRouter.HandleFunc("/fields/values", am.ViewAccess(aH.FieldsAPI.GetFieldsValues)).Methods(http.MethodGet) subRouter.HandleFunc("/fields/values", am.ViewAccess(aH.FieldsAPI.GetFieldsValues)).Methods(http.MethodGet)
} }
func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *middleware.AuthZ) {
hostsSubRouter := router.PathPrefix("/api/v1/hosts").Subrouter() hostsSubRouter := router.PathPrefix("/api/v1/hosts").Subrouter()
hostsSubRouter.HandleFunc("/attribute_keys", am.ViewAccess(aH.getHostAttributeKeys)).Methods(http.MethodGet) hostsSubRouter.HandleFunc("/attribute_keys", am.ViewAccess(aH.getHostAttributeKeys)).Methods(http.MethodGet)
hostsSubRouter.HandleFunc("/attribute_values", am.ViewAccess(aH.getHostAttributeValues)).Methods(http.MethodGet) hostsSubRouter.HandleFunc("/attribute_values", am.ViewAccess(aH.getHostAttributeValues)).Methods(http.MethodGet)
@ -498,12 +498,12 @@ func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *AuthMid
infraOnboardingSubRouter.HandleFunc("/k8s/status", am.ViewAccess(aH.getK8sInfraOnboardingStatus)).Methods(http.MethodGet) infraOnboardingSubRouter.HandleFunc("/k8s/status", am.ViewAccess(aH.getK8sInfraOnboardingStatus)).Methods(http.MethodGet)
} }
func (aH *APIHandler) RegisterWebSocketPaths(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterWebSocketPaths(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/ws").Subrouter() subRouter := router.PathPrefix("/ws").Subrouter()
subRouter.HandleFunc("/query_progress", am.ViewAccess(aH.GetQueryProgressUpdates)).Methods(http.MethodGet) subRouter.HandleFunc("/query_progress", am.ViewAccess(aH.GetQueryProgressUpdates)).Methods(http.MethodGet)
} }
func (aH *APIHandler) RegisterQueryRangeV4Routes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterQueryRangeV4Routes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v4").Subrouter() subRouter := router.PathPrefix("/api/v4").Subrouter()
subRouter.HandleFunc("/query_range", am.ViewAccess(aH.QueryRangeV4)).Methods(http.MethodPost) subRouter.HandleFunc("/query_range", am.ViewAccess(aH.QueryRangeV4)).Methods(http.MethodPost)
subRouter.HandleFunc("/metric/metric_metadata", am.ViewAccess(aH.getMetricMetadata)).Methods(http.MethodGet) subRouter.HandleFunc("/metric/metric_metadata", am.ViewAccess(aH.getMetricMetadata)).Methods(http.MethodGet)
@ -520,7 +520,7 @@ func (aH *APIHandler) RegisterPrivateRoutes(router *mux.Router) {
} }
// RegisterRoutes registers routes for this handler on the given router // RegisterRoutes registers routes for this handler on the given router
func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *middleware.AuthZ) {
router.HandleFunc("/api/v1/query_range", am.ViewAccess(aH.queryRangeMetrics)).Methods(http.MethodGet) router.HandleFunc("/api/v1/query_range", am.ViewAccess(aH.queryRangeMetrics)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/query", am.ViewAccess(aH.queryMetrics)).Methods(http.MethodGet) router.HandleFunc("/api/v1/query", am.ViewAccess(aH.queryMetrics)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/channels", am.ViewAccess(aH.AlertmanagerAPI.ListChannels)).Methods(http.MethodGet) router.HandleFunc("/api/v1/channels", am.ViewAccess(aH.AlertmanagerAPI.ListChannels)).Methods(http.MethodGet)
@ -649,7 +649,7 @@ func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *AuthMiddleware) {
})).Methods(http.MethodGet) })).Methods(http.MethodGet)
} }
func (ah *APIHandler) MetricExplorerRoutes(router *mux.Router, am *AuthMiddleware) { func (ah *APIHandler) MetricExplorerRoutes(router *mux.Router, am *middleware.AuthZ) {
router.HandleFunc("/api/v1/metrics/filters/keys", router.HandleFunc("/api/v1/metrics/filters/keys",
am.ViewAccess(ah.FilterKeysSuggestion)). am.ViewAccess(ah.FilterKeysSuggestion)).
Methods(http.MethodGet) Methods(http.MethodGet)
@ -757,9 +757,9 @@ func (aH *APIHandler) PopulateTemporality(ctx context.Context, qp *v3.QueryRange
} }
func (aH *APIHandler) listDowntimeSchedules(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) listDowntimeSchedules(w http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -1119,10 +1119,9 @@ func (aH *APIHandler) listRules(w http.ResponseWriter, r *http.Request) {
} }
func (aH *APIHandler) getDashboards(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) getDashboards(w http.ResponseWriter, r *http.Request) {
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
claims, ok := authtypes.ClaimsFromContext(r.Context()) if errv2 != nil {
if !ok { render.Error(w, errv2)
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
return return
} }
allDashboards, err := dashboards.GetDashboards(r.Context(), claims.OrgID) allDashboards, err := dashboards.GetDashboards(r.Context(), claims.OrgID)
@ -1189,11 +1188,10 @@ func (aH *APIHandler) getDashboards(w http.ResponseWriter, r *http.Request) {
} }
func (aH *APIHandler) deleteDashboard(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) deleteDashboard(w http.ResponseWriter, r *http.Request) {
uuid := mux.Vars(r)["uuid"] uuid := mux.Vars(r)["uuid"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
err := dashboards.DeleteDashboard(r.Context(), claims.OrgID, uuid) err := dashboards.DeleteDashboard(r.Context(), claims.OrgID, uuid)
@ -1268,7 +1266,6 @@ func (aH *APIHandler) queryDashboardVarsV2(w http.ResponseWriter, r *http.Reques
} }
func (aH *APIHandler) updateDashboard(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) updateDashboard(w http.ResponseWriter, r *http.Request) {
uuid := mux.Vars(r)["uuid"] uuid := mux.Vars(r)["uuid"]
var postData map[string]interface{} var postData map[string]interface{}
@ -1283,9 +1280,9 @@ func (aH *APIHandler) updateDashboard(w http.ResponseWriter, r *http.Request) {
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
dashboard, apiError := dashboards.UpdateDashboard(r.Context(), claims.OrgID, claims.Email, uuid, postData) dashboard, apiError := dashboards.UpdateDashboard(r.Context(), claims.OrgID, claims.Email, uuid, postData)
@ -1302,9 +1299,9 @@ func (aH *APIHandler) getDashboard(w http.ResponseWriter, r *http.Request) {
uuid := mux.Vars(r)["uuid"] uuid := mux.Vars(r)["uuid"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
dashboard, apiError := dashboards.GetDashboard(r.Context(), claims.OrgID, uuid) dashboard, apiError := dashboards.GetDashboard(r.Context(), claims.OrgID, uuid)
@ -1356,9 +1353,9 @@ func (aH *APIHandler) createDashboards(w http.ResponseWriter, r *http.Request) {
RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, "Error reading request body") RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, "Error reading request body")
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
dash, apiErr := dashboards.CreateDashboard(r.Context(), claims.OrgID, claims.Email, postData) dash, apiErr := dashboards.CreateDashboard(r.Context(), claims.OrgID, claims.Email, postData)
@ -1612,20 +1609,19 @@ func (aH *APIHandler) submitFeedback(w http.ResponseWriter, r *http.Request) {
"email": email, "email": email,
"message": message, "message": message,
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if ok { if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_INPRODUCT_FEEDBACK, data, claims.Email, true, false) telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_INPRODUCT_FEEDBACK, data, claims.Email, true, false)
} }
} }
func (aH *APIHandler) registerEvent(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) registerEvent(w http.ResponseWriter, r *http.Request) {
request, err := parseRegisterEventRequest(r) request, err := parseRegisterEventRequest(r)
if aH.HandleError(w, err, http.StatusBadRequest) { if aH.HandleError(w, err, http.StatusBadRequest) {
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if ok { if errv2 == nil {
switch request.EventType { switch request.EventType {
case model.TrackEvent: case model.TrackEvent:
telemetry.GetInstance().SendEvent(request.EventName, request.Attributes, claims.Email, request.RateLimited, true) telemetry.GetInstance().SendEvent(request.EventName, request.Attributes, claims.Email, request.RateLimited, true)
@ -1737,8 +1733,8 @@ func (aH *APIHandler) getServices(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{ data := map[string]interface{}{
"number": len(*result), "number": len(*result),
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if ok { if errv2 != nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_NUMBER_OF_SERVICES, data, claims.Email, true, false) telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_NUMBER_OF_SERVICES, data, claims.Email, true, false)
} }
@ -1918,8 +1914,8 @@ func (aH *APIHandler) setTTL(w http.ResponseWriter, r *http.Request) {
} }
ctx := r.Context() ctx := r.Context()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil) RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil)
return return
} }
@ -1946,8 +1942,8 @@ func (aH *APIHandler) getTTL(w http.ResponseWriter, r *http.Request) {
} }
ctx := r.Context() ctx := r.Context()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil) RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil)
return return
} }
@ -2030,9 +2026,10 @@ func (aH *APIHandler) inviteUser(w http.ResponseWriter, r *http.Request) {
resp, err := auth.Invite(r.Context(), req) resp, err := auth.Invite(r.Context(), req)
if err != nil { if err != nil {
RespondError(w, &model.ApiError{Err: err, Typ: model.ErrorInternal}, nil) render.Error(w, err)
return return
} }
aH.WriteJSON(w, r, resp) aH.WriteJSON(w, r, resp)
} }
@ -2088,11 +2085,10 @@ func (aH *APIHandler) revokeInvite(w http.ResponseWriter, r *http.Request) {
// listPendingInvites is used to list the pending invites. // listPendingInvites is used to list the pending invites.
func (aH *APIHandler) listPendingInvites(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) listPendingInvites(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil) render.Error(w, errv2)
return return
} }
invites, err := dao.DB().GetInvites(ctx, claims.OrgID) invites, err := dao.DB().GetInvites(ctx, claims.OrgID)
@ -2309,18 +2305,13 @@ func (aH *APIHandler) deleteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
adminGroup, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup) adminUsers, apiErr := dao.DB().GetUsersByRole(ctx, authtypes.RoleAdmin)
if apiErr != nil {
RespondError(w, apiErr, "Failed to get admin group")
return
}
adminUsers, apiErr := dao.DB().GetUsersByGroup(ctx, adminGroup.ID)
if apiErr != nil { if apiErr != nil {
RespondError(w, apiErr, "Failed to get admin group users") RespondError(w, apiErr, "Failed to get admin group users")
return return
} }
if user.GroupID == adminGroup.ID && len(adminUsers) == 1 { if user.Role == authtypes.RoleAdmin.String() && len(adminUsers) == 1 {
RespondError(w, &model.ApiError{ RespondError(w, &model.ApiError{
Typ: model.ErrorInternal, Typ: model.ErrorInternal,
Err: errors.New("cannot delete the last admin user")}, nil) Err: errors.New("cannot delete the last admin user")}, nil)
@ -2332,6 +2323,7 @@ func (aH *APIHandler) deleteUser(w http.ResponseWriter, r *http.Request) {
RespondError(w, err, "Failed to delete user") RespondError(w, err, "Failed to delete user")
return return
} }
aH.WriteJSON(w, r, map[string]string{"data": "user deleted successfully"}) aH.WriteJSON(w, r, map[string]string{"data": "user deleted successfully"})
} }
@ -2350,13 +2342,8 @@ func (aH *APIHandler) getRole(w http.ResponseWriter, r *http.Request) {
}, nil) }, nil)
return return
} }
group, err := dao.DB().GetGroup(context.Background(), user.GroupID)
if err != nil {
RespondError(w, err, "Failed to get group")
return
}
aH.WriteJSON(w, r, &model.UserRole{UserId: id, GroupName: group.Name}) aH.WriteJSON(w, r, &model.UserRole{UserId: id, GroupName: user.Role})
} }
func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) {
@ -2368,14 +2355,9 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) {
} }
ctx := context.Background() ctx := context.Background()
newGroup, apiErr := dao.DB().GetGroupByName(ctx, req.GroupName) role, err := authtypes.NewRole(req.GroupName)
if apiErr != nil { if err != nil {
RespondError(w, apiErr, "Failed to get user's group") RespondError(w, &model.ApiError{Typ: model.ErrorBadData, Err: errors.New("invalid role")}, nil)
return
}
if newGroup == nil {
RespondError(w, apiErr, "Specified group is not present")
return return
} }
@ -2386,8 +2368,8 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) {
} }
// Make sure that the request is not demoting the last admin user. // Make sure that the request is not demoting the last admin user.
if user.GroupID == auth.AuthCacheObj.AdminGroupId { if user.Role == authtypes.RoleAdmin.String() {
adminUsers, apiErr := dao.DB().GetUsersByGroup(ctx, auth.AuthCacheObj.AdminGroupId) adminUsers, apiErr := dao.DB().GetUsersByRole(ctx, authtypes.RoleAdmin)
if apiErr != nil { if apiErr != nil {
RespondError(w, apiErr, "Failed to fetch adminUsers") RespondError(w, apiErr, "Failed to fetch adminUsers")
return return
@ -2401,7 +2383,7 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) {
} }
} }
apiErr = dao.DB().UpdateUserGroup(context.Background(), user.ID, newGroup.ID) apiErr = dao.DB().UpdateUserRole(context.Background(), user.ID, role)
if apiErr != nil { if apiErr != nil {
RespondError(w, apiErr, "Failed to add user to group") RespondError(w, apiErr, "Failed to add user to group")
return return
@ -2525,7 +2507,7 @@ func (aH *APIHandler) WriteJSON(w http.ResponseWriter, r *http.Request, response
} }
// RegisterMessagingQueuesRoutes adds messaging-queues routes // RegisterMessagingQueuesRoutes adds messaging-queues routes
func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *middleware.AuthZ) {
// Main messaging queues router // Main messaging queues router
messagingQueuesRouter := router.PathPrefix("/api/v1/messaging-queues").Subrouter() messagingQueuesRouter := router.PathPrefix("/api/v1/messaging-queues").Subrouter()
@ -2567,7 +2549,7 @@ func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *Auth
} }
// RegisterThirdPartyApiRoutes adds third-party-api integration routes // RegisterThirdPartyApiRoutes adds third-party-api integration routes
func (aH *APIHandler) RegisterThirdPartyApiRoutes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterThirdPartyApiRoutes(router *mux.Router, am *middleware.AuthZ) {
// Main messaging queues router // Main messaging queues router
thirdPartyApiRouter := router.PathPrefix("/api/v1/third-party-apis").Subrouter() thirdPartyApiRouter := router.PathPrefix("/api/v1/third-party-apis").Subrouter()
@ -3494,7 +3476,7 @@ func (aH *APIHandler) getAllOrgPreferences(
} }
// RegisterIntegrationRoutes Registers all Integrations // RegisterIntegrationRoutes Registers all Integrations
func (aH *APIHandler) RegisterIntegrationRoutes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterIntegrationRoutes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v1/integrations").Subrouter() subRouter := router.PathPrefix("/api/v1/integrations").Subrouter()
subRouter.HandleFunc( subRouter.HandleFunc(
@ -3526,9 +3508,9 @@ func (aH *APIHandler) ListIntegrations(
for k, values := range r.URL.Query() { for k, values := range r.URL.Query() {
params[k] = values[0] params[k] = values[0]
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -3546,9 +3528,9 @@ func (aH *APIHandler) GetIntegration(
w http.ResponseWriter, r *http.Request, w http.ResponseWriter, r *http.Request,
) { ) {
integrationId := mux.Vars(r)["integrationId"] integrationId := mux.Vars(r)["integrationId"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
integration, apiErr := aH.IntegrationsController.GetIntegration( integration, apiErr := aH.IntegrationsController.GetIntegration(
@ -3566,9 +3548,9 @@ func (aH *APIHandler) GetIntegrationConnectionStatus(
w http.ResponseWriter, r *http.Request, w http.ResponseWriter, r *http.Request,
) { ) {
integrationId := mux.Vars(r)["integrationId"] integrationId := mux.Vars(r)["integrationId"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
isInstalled, apiErr := aH.IntegrationsController.IsIntegrationInstalled( isInstalled, apiErr := aH.IntegrationsController.IsIntegrationInstalled(
@ -3785,9 +3767,9 @@ func (aH *APIHandler) InstallIntegration(
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -3813,9 +3795,9 @@ func (aH *APIHandler) UninstallIntegration(
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -3829,7 +3811,7 @@ func (aH *APIHandler) UninstallIntegration(
} }
// cloud provider integrations // cloud provider integrations
func (aH *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v1/cloud-integrations").Subrouter() subRouter := router.PathPrefix("/api/v1/cloud-integrations").Subrouter()
subRouter.HandleFunc( subRouter.HandleFunc(
@ -3875,9 +3857,9 @@ func (aH *APIHandler) CloudIntegrationsListConnectedAccounts(
) { ) {
cloudProvider := mux.Vars(r)["cloudProvider"] cloudProvider := mux.Vars(r)["cloudProvider"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -3903,9 +3885,9 @@ func (aH *APIHandler) CloudIntegrationsGenerateConnectionUrl(
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -3927,9 +3909,9 @@ func (aH *APIHandler) CloudIntegrationsGetAccountStatus(
cloudProvider := mux.Vars(r)["cloudProvider"] cloudProvider := mux.Vars(r)["cloudProvider"]
accountId := mux.Vars(r)["accountId"] accountId := mux.Vars(r)["accountId"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -3955,9 +3937,9 @@ func (aH *APIHandler) CloudIntegrationsAgentCheckIn(
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -3985,9 +3967,9 @@ func (aH *APIHandler) CloudIntegrationsUpdateAccountConfig(
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -4009,9 +3991,9 @@ func (aH *APIHandler) CloudIntegrationsDisconnectAccount(
cloudProvider := mux.Vars(r)["cloudProvider"] cloudProvider := mux.Vars(r)["cloudProvider"]
accountId := mux.Vars(r)["accountId"] accountId := mux.Vars(r)["accountId"]
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -4039,9 +4021,9 @@ func (aH *APIHandler) CloudIntegrationsListServices(
cloudAccountId = &cloudAccountIdQP cloudAccountId = &cloudAccountIdQP
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -4069,9 +4051,9 @@ func (aH *APIHandler) CloudIntegrationsGetServiceDetails(
cloudAccountId = &cloudAccountIdQP cloudAccountId = &cloudAccountIdQP
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -4315,9 +4297,9 @@ func (aH *APIHandler) CloudIntegrationsUpdateServiceConfig(
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -4334,7 +4316,7 @@ func (aH *APIHandler) CloudIntegrationsUpdateServiceConfig(
} }
// logs // logs
func (aH *APIHandler) RegisterLogsRoutes(router *mux.Router, am *AuthMiddleware) { func (aH *APIHandler) RegisterLogsRoutes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v1/logs").Subrouter() subRouter := router.PathPrefix("/api/v1/logs").Subrouter()
subRouter.HandleFunc("", am.ViewAccess(aH.getLogs)).Methods(http.MethodGet) subRouter.HandleFunc("", am.ViewAccess(aH.getLogs)).Methods(http.MethodGet)
subRouter.HandleFunc("/tail", am.ViewAccess(aH.tailLogs)).Methods(http.MethodGet) subRouter.HandleFunc("/tail", am.ViewAccess(aH.tailLogs)).Methods(http.MethodGet)
@ -4498,9 +4480,9 @@ func (aH *APIHandler) PreviewLogsPipelinesHandler(w http.ResponseWriter, r *http
} }
func (aH *APIHandler) ListLogsPipelinesHandler(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) ListLogsPipelinesHandler(w http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
@ -4577,10 +4559,9 @@ func (aH *APIHandler) listLogsPipelinesByVersion(ctx context.Context, orgID stri
} }
func (aH *APIHandler) CreateLogsPipeline(w http.ResponseWriter, r *http.Request) { func (aH *APIHandler) CreateLogsPipeline(w http.ResponseWriter, r *http.Request) {
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
claims, ok := authtypes.ClaimsFromContext(r.Context()) if errv2 != nil {
if !ok { render.Error(w, errv2)
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
return return
} }
@ -4622,11 +4603,12 @@ func (aH *APIHandler) getSavedViews(w http.ResponseWriter, r *http.Request) {
name := r.URL.Query().Get("name") name := r.URL.Query().Get("name")
category := r.URL.Query().Get("category") category := r.URL.Query().Get("category")
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
queries, err := explorer.GetViewsForFilters(r.Context(), claims.OrgID, sourcePage, name, category) queries, err := explorer.GetViewsForFilters(r.Context(), claims.OrgID, sourcePage, name, category)
if err != nil { if err != nil {
RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, nil) RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, nil)
@ -4648,9 +4630,9 @@ func (aH *APIHandler) createSavedViews(w http.ResponseWriter, r *http.Request) {
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
uuid, err := explorer.CreateView(r.Context(), claims.OrgID, view) uuid, err := explorer.CreateView(r.Context(), claims.OrgID, view)
@ -4670,9 +4652,9 @@ func (aH *APIHandler) getSavedView(w http.ResponseWriter, r *http.Request) {
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
view, err := explorer.GetView(r.Context(), claims.OrgID, viewUUID) view, err := explorer.GetView(r.Context(), claims.OrgID, viewUUID)
@ -4703,9 +4685,9 @@ func (aH *APIHandler) updateSavedView(w http.ResponseWriter, r *http.Request) {
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
err = explorer.UpdateView(r.Context(), claims.OrgID, viewUUID, view) err = explorer.UpdateView(r.Context(), claims.OrgID, viewUUID, view)
@ -4725,9 +4707,9 @@ func (aH *APIHandler) deleteSavedView(w http.ResponseWriter, r *http.Request) {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, err.Error())) render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, err.Error()))
return return
} }
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if !ok { if errv2 != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated")) render.Error(w, errv2)
return return
} }
err = explorer.DeleteView(r.Context(), claims.OrgID, viewUUID) err = explorer.DeleteView(r.Context(), claims.OrgID, viewUUID)
@ -5050,8 +5032,8 @@ func sendQueryResultEvents(r *http.Request, result []*v3.Result, queryRangeParam
if len(result) > 0 && (len(result[0].Series) > 0 || len(result[0].List) > 0) { if len(result) > 0 && (len(result[0].Series) > 0 || len(result[0].List) > 0) {
claims, ok := authtypes.ClaimsFromContext(r.Context()) claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if ok { if errv2 == nil {
queryInfoResult := telemetry.GetInstance().CheckQueryInfo(queryRangeParams) queryInfoResult := telemetry.GetInstance().CheckQueryInfo(queryRangeParams)
if queryInfoResult.LogsUsed || queryInfoResult.MetricsUsed || queryInfoResult.TracesUsed { if queryInfoResult.LogsUsed || queryInfoResult.MetricsUsed || queryInfoResult.TracesUsed {

View File

@ -6,14 +6,13 @@ import (
"testing" "testing"
"github.com/SigNoz/signoz/pkg/modules/organization" "github.com/SigNoz/signoz/pkg/modules/organization"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/model"
v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3" v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3"
"github.com/SigNoz/signoz/pkg/query-service/utils" "github.com/SigNoz/signoz/pkg/query-service/utils"
"github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/sqlstore"
"github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/types/pipelinetypes" "github.com/SigNoz/signoz/pkg/types/pipelinetypes"
ruletypes "github.com/SigNoz/signoz/pkg/types/ruletypes" ruletypes "github.com/SigNoz/signoz/pkg/types/ruletypes"
"github.com/google/uuid" "github.com/google/uuid"
@ -42,13 +41,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup)
if apiErr != nil {
return nil, model.InternalError(apiErr)
}
auth.InitAuthCache(ctx)
userId := uuid.NewString() userId := uuid.NewString()
return dao.DB().CreateUser( return dao.DB().CreateUser(
ctx, ctx,
@ -58,7 +50,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
Email: userId[:8] + "test@test.com", Email: userId[:8] + "test@test.com",
Password: "test", Password: "test",
OrgID: organization.ID.StringValue(), OrgID: organization.ID.StringValue(),
GroupID: group.ID, Role: authtypes.RoleAdmin.String(),
}, },
true, true,
) )

View File

@ -54,8 +54,8 @@ func (ic *LogParsingPipelineController) ApplyPipelines(
postable []pipelinetypes.PostablePipeline, postable []pipelinetypes.PostablePipeline,
) (*PipelinesResponse, *model.ApiError) { ) (*PipelinesResponse, *model.ApiError) {
// get user id from context // get user id from context
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
return nil, model.UnauthorizedError(fmt.Errorf("failed to get userId from context")) return nil, model.UnauthorizedError(fmt.Errorf("failed to get userId from context"))
} }

View File

@ -53,8 +53,8 @@ func (r *Repo) insertPipeline(
)) ))
} }
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
return nil, model.UnauthorizedError(fmt.Errorf("failed to get email from context")) return nil, model.UnauthorizedError(fmt.Errorf("failed to get email from context"))
} }

View File

@ -159,9 +159,9 @@ func (receiver *SummaryService) GetMetricsSummary(ctx context.Context, metricNam
g.Go(func() error { g.Go(func() error {
var metricNames []string var metricNames []string
metricNames = append(metricNames, metricName) metricNames = append(metricNames, metricName)
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
return &model.ApiError{Typ: model.ErrorInternal, Err: errors.New("failed to get claims")} return &model.ApiError{Typ: model.ErrorInternal, Err: errv2}
} }
data, err := dashboards.GetDashboardsWithMetricNames(ctx, claims.OrgID, metricNames) data, err := dashboards.GetDashboardsWithMetricNames(ctx, claims.OrgID, metricNames)
if err != nil { if err != nil {
@ -332,9 +332,9 @@ func (receiver *SummaryService) GetRelatedMetrics(ctx context.Context, params *m
alertsRelatedData := make(map[string][]metrics_explorer.Alert) alertsRelatedData := make(map[string][]metrics_explorer.Alert)
g.Go(func() error { g.Go(func() error {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, errv2 := authtypes.ClaimsFromContext(ctx)
if !ok { if errv2 != nil {
return &model.ApiError{Typ: model.ErrorInternal, Err: errors.New("failed to get claims")} return &model.ApiError{Typ: model.ErrorInternal, Err: errv2}
} }
names, apiError := dashboards.GetDashboardsWithMetricNames(ctx, claims.OrgID, metricNames) names, apiError := dashboards.GetDashboardsWithMetricNames(ctx, claims.OrgID, metricNames)
if apiError != nil { if apiError != nil {

View File

@ -16,6 +16,7 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/app/integrations/messagingQueues/kafka" "github.com/SigNoz/signoz/pkg/query-service/app/integrations/messagingQueues/kafka"
queues2 "github.com/SigNoz/signoz/pkg/query-service/app/integrations/messagingQueues/queues" queues2 "github.com/SigNoz/signoz/pkg/query-service/app/integrations/messagingQueues/queues"
"github.com/SigNoz/signoz/pkg/query-service/app/integrations/thirdPartyApi" "github.com/SigNoz/signoz/pkg/query-service/app/integrations/thirdPartyApi"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/govaluate" "github.com/SigNoz/govaluate"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -27,7 +28,6 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/app/queryBuilder" "github.com/SigNoz/signoz/pkg/query-service/app/queryBuilder"
"github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/common" "github.com/SigNoz/signoz/pkg/query-service/common"
"github.com/SigNoz/signoz/pkg/query-service/constants"
baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants" baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/model"
v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3" v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3"
@ -492,14 +492,6 @@ func parseInviteRequest(r *http.Request) (*model.InviteRequest, error) {
return &req, nil return &req, nil
} }
func isValidRole(role string) bool {
switch role {
case constants.AdminGroup, constants.EditorGroup, constants.ViewerGroup:
return true
}
return false
}
func parseInviteUsersRequest(r *http.Request) (*model.BulkInviteRequest, error) { func parseInviteUsersRequest(r *http.Request) (*model.BulkInviteRequest, error) {
var req model.BulkInviteRequest var req model.BulkInviteRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -520,7 +512,9 @@ func parseInviteUsersRequest(r *http.Request) (*model.BulkInviteRequest, error)
if req.Users[i].FrontendBaseUrl == "" { if req.Users[i].FrontendBaseUrl == "" {
return nil, fmt.Errorf("frontendBaseUrl is required for each user") return nil, fmt.Errorf("frontendBaseUrl is required for each user")
} }
if !isValidRole(req.Users[i].Role) {
_, err := authtypes.NewRole(req.Users[i].Role)
if err != nil {
return nil, fmt.Errorf("invalid role for user: %s", req.Users[i].Email) return nil, fmt.Errorf("invalid role for user: %s", req.Users[i].Email)
} }
} }

View File

@ -2,7 +2,6 @@ package app
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -30,7 +29,6 @@ import (
"github.com/SigNoz/signoz/pkg/signoz" "github.com/SigNoz/signoz/pkg/signoz"
"github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/sqlstore"
"github.com/SigNoz/signoz/pkg/telemetrystore" "github.com/SigNoz/signoz/pkg/telemetrystore"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/types/preferencetypes" "github.com/SigNoz/signoz/pkg/types/preferencetypes"
"github.com/SigNoz/signoz/pkg/web" "github.com/SigNoz/signoz/pkg/web"
@ -38,7 +36,6 @@ import (
"github.com/soheilhy/cmux" "github.com/soheilhy/cmux"
"github.com/SigNoz/signoz/pkg/query-service/app/explorer" "github.com/SigNoz/signoz/pkg/query-service/app/explorer"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/cache" "github.com/SigNoz/signoz/pkg/query-service/cache"
"github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/dao"
@ -308,21 +305,7 @@ func (s *Server) createPublicServer(api *APIHandler, web web.Web) (*http.Server,
r.Use(middleware.NewAnalytics(zap.L()).Wrap) r.Use(middleware.NewAnalytics(zap.L()).Wrap)
r.Use(middleware.NewLogging(zap.L(), s.serverOptions.Config.APIServer.Logging.ExcludedRoutes).Wrap) r.Use(middleware.NewLogging(zap.L(), s.serverOptions.Config.APIServer.Logging.ExcludedRoutes).Wrap)
// add auth middleware am := middleware.NewAuthZ(s.serverOptions.SigNoz.Instrumentation.Logger())
getUserFromRequest := func(ctx context.Context) (*types.GettableUser, error) {
user, err := auth.GetUserFromReqContext(ctx)
if err != nil {
return nil, err
}
if user.User.OrgID == "" {
return nil, model.UnauthorizedError(errors.New("orgId is missing in the claims"))
}
return user, nil
}
am := NewAuthMiddleware(getUserFromRequest)
api.RegisterRoutes(r, am) api.RegisterRoutes(r, am)
api.RegisterLogsRoutes(r, am) api.RegisterLogsRoutes(r, am)

View File

@ -26,18 +26,17 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
type JwtContextKeyType string
const AccessJwtKey JwtContextKeyType = "accessJwt"
const RefreshJwtKey JwtContextKeyType = "refreshJwt"
const ( const (
opaqueTokenSize = 16 opaqueTokenSize = 16
minimumPasswordLength = 8 minimumPasswordLength = 8
) )
var ( var (
ErrorInvalidCreds = fmt.Errorf("invalid credentials") ErrorInvalidCreds = fmt.Errorf("invalid credentials")
ErrorEmptyRequest = errors.New("Empty request")
ErrorInvalidRole = errors.New("Invalid role")
ErrorInvalidInviteToken = errors.New("Invalid invite token")
ErrorAskAdmin = errors.New("An invitation is needed to create an account. Please ask your admin (the person who has first installed SIgNoz) to send an invite.")
) )
type InviteEmailData struct { type InviteEmailData struct {
@ -49,7 +48,10 @@ type InviteEmailData struct {
// The root user should be able to invite people to create account on SigNoz cluster. // The root user should be able to invite people to create account on SigNoz cluster.
func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteResponse, error) { func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteResponse, error) {
zap.L().Debug("Got an invite request for email", zap.String("email", req.Email)) claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return nil, err
}
token, err := utils.RandomHex(opaqueTokenSize) token, err := utils.RandomHex(opaqueTokenSize)
if err != nil { if err != nil {
@ -64,11 +66,6 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons
if user != nil { if user != nil {
return nil, errors.New("User already exists with the same email") return nil, errors.New("User already exists with the same email")
} }
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errors.New("failed to extract OrgID from context")
}
// Check if an invite already exists // Check if an invite already exists
invite, apiErr := dao.DB().GetInviteFromEmail(ctx, req.Email) invite, apiErr := dao.DB().GetInviteFromEmail(ctx, req.Email)
if apiErr != nil { if apiErr != nil {
@ -79,8 +76,9 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons
return nil, errors.New("An invite already exists for this email") return nil, errors.New("An invite already exists for this email")
} }
if err := validateInviteRequest(req); err != nil { role, err := authtypes.NewRole(req.Role)
return nil, errors.Wrap(err, "invalid invite request") if err != nil {
return nil, err
} }
au, apiErr := dao.DB().GetUser(ctx, claims.UserID) au, apiErr := dao.DB().GetUser(ctx, claims.UserID)
@ -99,7 +97,7 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons
Name: req.Name, Name: req.Name,
Email: req.Email, Email: req.Email,
Token: token, Token: token,
Role: req.Role, Role: role.String(),
OrgID: au.OrgID, OrgID: au.OrgID,
} }
@ -120,6 +118,11 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons
} }
func InviteUsers(ctx context.Context, req *model.BulkInviteRequest) (*model.BulkInviteResponse, error) { func InviteUsers(ctx context.Context, req *model.BulkInviteRequest) (*model.BulkInviteResponse, error) {
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return nil, err
}
response := &model.BulkInviteResponse{ response := &model.BulkInviteResponse{
Status: "success", Status: "success",
Summary: model.InviteSummary{TotalInvites: len(req.Users)}, Summary: model.InviteSummary{TotalInvites: len(req.Users)},
@ -127,11 +130,6 @@ func InviteUsers(ctx context.Context, req *model.BulkInviteRequest) (*model.Bulk
FailedInvites: []model.FailedInvite{}, FailedInvites: []model.FailedInvite{},
} }
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errors.New("failed to extract admin user id")
}
au, apiErr := dao.DB().GetUser(ctx, claims.UserID) au, apiErr := dao.DB().GetUser(ctx, claims.UserID)
if apiErr != nil { if apiErr != nil {
return nil, errors.Wrap(apiErr.Err, "failed to query admin user from the DB") return nil, errors.Wrap(apiErr.Err, "failed to query admin user from the DB")
@ -191,8 +189,9 @@ func inviteUser(ctx context.Context, req *model.InviteRequest, au *types.Gettabl
return nil, errors.New("An invite already exists for this email") return nil, errors.New("An invite already exists for this email")
} }
if err := validateInviteRequest(req); err != nil { role, err := authtypes.NewRole(req.Role)
return nil, errors.Wrap(err, "invalid invite request") if err != nil {
return nil, err
} }
inv := &types.Invite{ inv := &types.Invite{
@ -206,7 +205,7 @@ func inviteUser(ctx context.Context, req *model.InviteRequest, au *types.Gettabl
Name: req.Name, Name: req.Name,
Email: req.Email, Email: req.Email,
Token: token, Token: token,
Role: req.Role, Role: role.String(),
OrgID: au.OrgID, OrgID: au.OrgID,
} }
@ -260,15 +259,9 @@ func inviteEmail(req *model.InviteRequest, au *types.GettableUser, token string)
// RevokeInvite is used to revoke the invitation for the given email. // RevokeInvite is used to revoke the invitation for the given email.
func RevokeInvite(ctx context.Context, email string) error { func RevokeInvite(ctx context.Context, email string) error {
zap.L().Debug("RevokeInvite method invoked for email", zap.String("email", email)) claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
if !isValidEmail(email) { return err
return ErrorInvalidInviteToken
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return errors.New("failed to org id from context")
} }
if err := dao.DB().DeleteInvitation(ctx, claims.OrgID, email); err != nil { if err := dao.DB().DeleteInvitation(ctx, claims.OrgID, email); err != nil {
@ -414,19 +407,12 @@ func RegisterFirstUser(ctx context.Context, req *RegisterRequest, organizationMo
return nil, model.BadRequest(model.ErrPasswordRequired{}) return nil, model.BadRequest(model.ErrPasswordRequired{})
} }
groupName := constants.AdminGroup
organization := types.NewOrganization(req.OrgDisplayName) organization := types.NewOrganization(req.OrgDisplayName)
err := organizationModule.Create(ctx, organization) err := organizationModule.Create(ctx, organization)
if err != nil { if err != nil {
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
group, apiErr := dao.DB().GetGroupByName(ctx, groupName)
if apiErr != nil {
zap.L().Error("GetGroupByName failed", zap.Error(apiErr.Err))
return nil, apiErr
}
var hash string var hash string
hash, err = PasswordHash(req.Password) hash, err = PasswordHash(req.Password)
if err != nil { if err != nil {
@ -443,7 +429,7 @@ func RegisterFirstUser(ctx context.Context, req *RegisterRequest, organizationMo
CreatedAt: time.Now(), CreatedAt: time.Now(),
}, },
ProfilePictureURL: "", // Currently unused ProfilePictureURL: "", // Currently unused
GroupID: group.ID, Role: authtypes.RoleAdmin.String(),
OrgID: organization.ID.StringValue(), OrgID: organization.ID.StringValue(),
} }
@ -488,13 +474,7 @@ func RegisterInvitedUser(ctx context.Context, req *RegisterRequest, nopassword b
if invite.Role == "" { if invite.Role == "" {
// if role is not provided, default to viewer // if role is not provided, default to viewer
invite.Role = constants.ViewerGroup invite.Role = authtypes.RoleViewer.String()
}
group, apiErr := dao.DB().GetGroupByName(ctx, invite.Role)
if apiErr != nil {
zap.L().Error("GetGroupByName failed", zap.Error(apiErr.Err))
return nil, model.InternalError(model.ErrSignupFailed{})
} }
var hash string var hash string
@ -523,12 +503,12 @@ func RegisterInvitedUser(ctx context.Context, req *RegisterRequest, nopassword b
CreatedAt: time.Now(), CreatedAt: time.Now(),
}, },
ProfilePictureURL: "", // Currently unused ProfilePictureURL: "", // Currently unused
GroupID: group.ID, Role: invite.Role,
OrgID: invite.OrgID, OrgID: invite.OrgID,
} }
// TODO(Ahsan): Ideally create user and delete invitation should happen in a txn. // TODO(Ahsan): Ideally create user and delete invitation should happen in a txn.
user, apiErr = dao.DB().CreateUser(ctx, user, false) user, apiErr := dao.DB().CreateUser(ctx, user, false)
if apiErr != nil { if apiErr != nil {
zap.L().Error("CreateUser failed", zap.Error(apiErr.Err)) zap.L().Error("CreateUser failed", zap.Error(apiErr.Err))
return nil, apiErr return nil, apiErr
@ -574,8 +554,6 @@ func Register(ctx context.Context, req *RegisterRequest, alertmanager alertmanag
// Login method returns access and refresh tokens on successful login, else it errors out. // Login method returns access and refresh tokens on successful login, else it errors out.
func Login(ctx context.Context, request *model.LoginRequest, jwt *authtypes.JWT) (*model.LoginResponse, error) { func Login(ctx context.Context, request *model.LoginRequest, jwt *authtypes.JWT) (*model.LoginResponse, error) {
zap.L().Debug("Login method called for user", zap.String("email", request.Email))
user, err := authenticateLogin(ctx, request, jwt) user, err := authenticateLogin(ctx, request, jwt)
if err != nil { if err != nil {
zap.L().Error("Failed to authenticate login request", zap.Error(err)) zap.L().Error("Failed to authenticate login request", zap.Error(err))
@ -599,18 +577,6 @@ func Login(ctx context.Context, request *model.LoginRequest, jwt *authtypes.JWT)
}, nil }, nil
} }
func claimsToUserPayload(claims authtypes.Claims) (*types.GettableUser, error) {
user := &types.GettableUser{
User: types.User{
ID: claims.UserID,
GroupID: claims.GroupID,
Email: claims.Email,
OrgID: claims.OrgID,
},
}
return user, nil
}
// authenticateLogin is responsible for querying the DB and validating the credentials. // authenticateLogin is responsible for querying the DB and validating the credentials.
func authenticateLogin(ctx context.Context, req *model.LoginRequest, jwt *authtypes.JWT) (*types.GettableUser, error) { func authenticateLogin(ctx context.Context, req *model.LoginRequest, jwt *authtypes.JWT) (*types.GettableUser, error) {
// If refresh token is valid, then simply authorize the login request. // If refresh token is valid, then simply authorize the login request.
@ -621,13 +587,13 @@ func authenticateLogin(ctx context.Context, req *model.LoginRequest, jwt *authty
return nil, errors.Wrap(err, "failed to parse refresh token") return nil, errors.Wrap(err, "failed to parse refresh token")
} }
if claims.OrgID == "" { user := &types.GettableUser{
return nil, model.UnauthorizedError(errors.New("orgId is missing in the claims")) User: types.User{
} ID: claims.UserID,
Role: claims.Role.String(),
user, err := claimsToUserPayload(claims) Email: claims.Email,
if err != nil { OrgID: claims.OrgID,
return nil, errors.Wrap(err, "failed to convert claims to user payload") },
} }
return user, nil return user, nil
} }
@ -658,18 +624,32 @@ func passwordMatch(hash, password string) bool {
} }
func GenerateJWTForUser(user *types.User, jwt *authtypes.JWT) (model.UserJwtObject, error) { func GenerateJWTForUser(user *types.User, jwt *authtypes.JWT) (model.UserJwtObject, error) {
j := model.UserJwtObject{} role, err := authtypes.NewRole(user.Role)
var err error
j.AccessJwtExpiry = time.Now().Add(jwt.JwtExpiry).Unix()
j.AccessJwt, err = jwt.AccessToken(user.OrgID, user.ID, user.GroupID, user.Email)
if err != nil { if err != nil {
return j, errors.Errorf("failed to encode jwt: %v", err) return model.UserJwtObject{}, err
} }
j.RefreshJwtExpiry = time.Now().Add(jwt.JwtRefresh).Unix() accessJwt, accessClaims, err := jwt.AccessToken(user.OrgID, user.ID, user.Email, role)
j.RefreshJwt, err = jwt.RefreshToken(user.OrgID, user.ID, user.GroupID, user.Email)
if err != nil { if err != nil {
return j, errors.Errorf("failed to encode jwt: %v", err) return model.UserJwtObject{}, err
} }
return j, nil
refreshJwt, refreshClaims, err := jwt.RefreshToken(user.OrgID, user.ID, user.Email, role)
if err != nil {
return model.UserJwtObject{}, err
}
return model.UserJwtObject{
AccessJwt: accessJwt,
RefreshJwt: refreshJwt,
AccessJwtExpiry: accessClaims.ExpiresAt.Unix(),
RefreshJwtExpiry: refreshClaims.ExpiresAt.Unix(),
}, nil
}
func ValidatePassword(password string) error {
if len(password) < minimumPasswordLength {
return errors.Errorf("Password should be atleast %d characters.", minimumPasswordLength)
}
return nil
} }

View File

@ -1,82 +0,0 @@
package auth
import (
"context"
errorsV2 "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/pkg/errors"
)
type Group struct {
GroupID string
GroupName string
}
type AuthCache struct {
AdminGroupId string
EditorGroupId string
ViewerGroupId string
}
var AuthCacheObj AuthCache
// InitAuthCache reads the DB and initialize the auth cache.
func InitAuthCache(ctx context.Context) error {
setGroupId := func(groupName string, dest *string) error {
group, err := dao.DB().GetGroupByName(ctx, groupName)
if err != nil {
return errors.Wrapf(err.Err, "failed to get group %s", groupName)
}
*dest = group.ID
return nil
}
if err := setGroupId(constants.AdminGroup, &AuthCacheObj.AdminGroupId); err != nil {
return err
}
if err := setGroupId(constants.EditorGroup, &AuthCacheObj.EditorGroupId); err != nil {
return err
}
if err := setGroupId(constants.ViewerGroup, &AuthCacheObj.ViewerGroupId); err != nil {
return err
}
return nil
}
func GetUserFromReqContext(ctx context.Context) (*types.GettableUser, error) {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errorsV2.New(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, "no claims found in context")
}
user := &types.GettableUser{
User: types.User{
ID: claims.UserID,
GroupID: claims.GroupID,
Email: claims.Email,
OrgID: claims.OrgID,
},
}
return user, nil
}
func IsSelfAccessRequest(user *types.GettableUser, id string) bool { return user.ID == id }
func IsViewer(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.ViewerGroupId }
func IsEditor(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.EditorGroupId }
func IsAdmin(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.AdminGroupId }
func IsAdminV2(claims authtypes.Claims) bool { return claims.GroupID == AuthCacheObj.AdminGroupId }
func ValidatePassword(password string) error {
if len(password) < minimumPasswordLength {
return errors.Errorf("Password should be atleast %d characters.", minimumPasswordLength)
}
return nil
}

View File

@ -1,43 +0,0 @@
package auth
import (
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/pkg/errors"
)
var (
ErrorEmptyRequest = errors.New("Empty request")
ErrorInvalidEmail = errors.New("Invalid email")
ErrorInvalidRole = errors.New("Invalid role")
ErrorInvalidInviteToken = errors.New("Invalid invite token")
ErrorAskAdmin = errors.New("An invitation is needed to create an account. Please ask your admin (the person who has first installed SIgNoz) to send an invite.")
)
func isValidRole(role string) bool {
switch role {
case constants.AdminGroup, constants.EditorGroup, constants.ViewerGroup:
return true
}
return false
}
func validateInviteRequest(req *model.InviteRequest) error {
if req == nil {
return ErrorEmptyRequest
}
if !isValidEmail(req.Email) {
return ErrorInvalidEmail
}
if !isValidRole(req.Role) {
return ErrorInvalidRole
}
return nil
}
// TODO(Ahsan): Implement check on email semantic.
func isValidEmail(email string) bool {
return true
}

View File

@ -1,7 +0,0 @@
package constants
const (
AdminGroup = "ADMIN"
EditorGroup = "EDITOR"
ViewerGroup = "VIEWER"
)

View File

@ -5,6 +5,7 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
) )
type ModelDao interface { type ModelDao interface {
@ -22,13 +23,9 @@ type Queries interface {
GetUsers(ctx context.Context) ([]types.GettableUser, *model.ApiError) GetUsers(ctx context.Context) ([]types.GettableUser, *model.ApiError)
GetUsersWithOpts(ctx context.Context, limit int) ([]types.GettableUser, *model.ApiError) GetUsersWithOpts(ctx context.Context, limit int) ([]types.GettableUser, *model.ApiError)
GetGroup(ctx context.Context, id string) (*types.Group, *model.ApiError)
GetGroupByName(ctx context.Context, name string) (*types.Group, *model.ApiError)
GetGroups(ctx context.Context) ([]types.Group, *model.ApiError)
GetResetPasswordEntry(ctx context.Context, token string) (*types.ResetPasswordRequest, *model.ApiError) GetResetPasswordEntry(ctx context.Context, token string) (*types.ResetPasswordRequest, *model.ApiError)
GetUsersByOrg(ctx context.Context, orgId string) ([]types.GettableUser, *model.ApiError) GetUsersByOrg(ctx context.Context, orgId string) ([]types.GettableUser, *model.ApiError)
GetUsersByGroup(ctx context.Context, groupId string) ([]types.GettableUser, *model.ApiError) GetUsersByRole(ctx context.Context, role authtypes.Role) ([]types.GettableUser, *model.ApiError)
GetApdexSettings(ctx context.Context, orgID string, services []string) ([]types.ApdexSettings, *model.ApiError) GetApdexSettings(ctx context.Context, orgID string, services []string) ([]types.ApdexSettings, *model.ApiError)
@ -43,14 +40,11 @@ type Mutations interface {
EditUser(ctx context.Context, update *types.User) (*types.User, *model.ApiError) EditUser(ctx context.Context, update *types.User) (*types.User, *model.ApiError)
DeleteUser(ctx context.Context, id string) *model.ApiError DeleteUser(ctx context.Context, id string) *model.ApiError
CreateGroup(ctx context.Context, group *types.Group) (*types.Group, *model.ApiError)
DeleteGroup(ctx context.Context, id string) *model.ApiError
CreateResetPasswordEntry(ctx context.Context, req *types.ResetPasswordRequest) *model.ApiError CreateResetPasswordEntry(ctx context.Context, req *types.ResetPasswordRequest) *model.ApiError
DeleteResetPasswordEntry(ctx context.Context, token string) *model.ApiError DeleteResetPasswordEntry(ctx context.Context, token string) *model.ApiError
UpdateUserPassword(ctx context.Context, hash, userId string) *model.ApiError UpdateUserPassword(ctx context.Context, hash, userId string) *model.ApiError
UpdateUserGroup(ctx context.Context, userId, groupId string) *model.ApiError UpdateUserRole(ctx context.Context, userId string, role authtypes.Role) *model.ApiError
SetApdexSettings(ctx context.Context, orgID string, set *types.ApdexSettings) *model.ApiError SetApdexSettings(ctx context.Context, orgID string, set *types.ApdexSettings) *model.ApiError
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"go.uber.org/zap"
) )
type ModelDaoSqlite struct { type ModelDaoSqlite struct {
@ -24,12 +23,8 @@ func InitDB(sqlStore sqlstore.SQLStore) (*ModelDaoSqlite, error) {
if err := mds.initializeOrgPreferences(ctx); err != nil { if err := mds.initializeOrgPreferences(ctx); err != nil {
return nil, err return nil, err
} }
if err := mds.initializeRBAC(ctx); err != nil {
return nil, err
}
telemetry.GetInstance().SetUserCountCallback(mds.GetUserCount) telemetry.GetInstance().SetUserCountCallback(mds.GetUserCount)
telemetry.GetInstance().SetUserRoleCallback(mds.GetUserRole)
telemetry.GetInstance().SetGetUsersCallback(mds.GetUsers) telemetry.GetInstance().SetGetUsersCallback(mds.GetUsers)
return mds, nil return mds, nil
@ -76,42 +71,3 @@ func (mds *ModelDaoSqlite) initializeOrgPreferences(ctx context.Context) error {
return nil return nil
} }
// initializeRBAC creates the ADMIN, EDITOR and VIEWER groups if they are not present.
func (mds *ModelDaoSqlite) initializeRBAC(ctx context.Context) error {
f := func(groupName string) error {
_, err := mds.createGroupIfNotPresent(ctx, groupName)
return errors.Wrap(err, "Failed to create group")
}
if err := f(constants.AdminGroup); err != nil {
return err
}
if err := f(constants.EditorGroup); err != nil {
return err
}
if err := f(constants.ViewerGroup); err != nil {
return err
}
return nil
}
func (mds *ModelDaoSqlite) createGroupIfNotPresent(ctx context.Context,
name string) (*types.Group, error) {
group, err := mds.GetGroupByName(ctx, name)
if err != nil {
return nil, errors.Wrap(err.Err, "Failed to query for root group")
}
if group != nil {
return group, nil
}
zap.L().Debug("group is not found, creating it", zap.String("group_name", name))
group, cErr := mds.CreateGroup(ctx, &types.Group{Name: name})
if cErr != nil {
return nil, cErr.Err
}
return group, nil
}

View File

@ -7,7 +7,7 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/query-service/telemetry" "github.com/SigNoz/signoz/pkg/query-service/telemetry"
"github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types"
"github.com/google/uuid" "github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -163,11 +163,11 @@ func (mds *ModelDaoSqlite) UpdateUserPassword(ctx context.Context, passwordHash,
return nil return nil
} }
func (mds *ModelDaoSqlite) UpdateUserGroup(ctx context.Context, userId, groupId string) *model.ApiError { func (mds *ModelDaoSqlite) UpdateUserRole(ctx context.Context, userId string, role authtypes.Role) *model.ApiError {
_, err := mds.bundb.NewUpdate(). _, err := mds.bundb.NewUpdate().
Model(&types.User{}). Model(&types.User{}).
Set("group_id = ?", groupId). Set("role = ?", role).
Where("id = ?", userId). Where("id = ?", userId).
Exec(ctx) Exec(ctx)
@ -207,10 +207,8 @@ func (mds *ModelDaoSqlite) GetUser(ctx context.Context,
users := []types.GettableUser{} users := []types.GettableUser{}
query := mds.bundb.NewSelect(). query := mds.bundb.NewSelect().
Table("users"). Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("g.name as role"). ColumnExpr("o.name as organization").
ColumnExpr("o.display_name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Join("JOIN organizations o ON o.id = users.org_id"). Join("JOIN organizations o ON o.id = users.org_id").
Where("users.id = ?", id) Where("users.id = ?", id)
@ -244,10 +242,8 @@ func (mds *ModelDaoSqlite) GetUserByEmail(ctx context.Context,
users := []types.GettableUser{} users := []types.GettableUser{}
query := mds.bundb.NewSelect(). query := mds.bundb.NewSelect().
Table("users"). Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("g.name as role"). ColumnExpr("o.name as organization").
ColumnExpr("o.display_name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Join("JOIN organizations o ON o.id = users.org_id"). Join("JOIN organizations o ON o.id = users.org_id").
Where("users.email = ?", email) Where("users.email = ?", email)
@ -279,10 +275,9 @@ func (mds *ModelDaoSqlite) GetUsersWithOpts(ctx context.Context, limit int) ([]t
query := mds.bundb.NewSelect(). query := mds.bundb.NewSelect().
Table("users"). Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("g.name as role"). ColumnExpr("users.role as role").
ColumnExpr("o.display_name as organization"). ColumnExpr("o.name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Join("JOIN organizations o ON o.id = users.org_id") Join("JOIN organizations o ON o.id = users.org_id")
if limit > 0 { if limit > 0 {
@ -303,10 +298,9 @@ func (mds *ModelDaoSqlite) GetUsersByOrg(ctx context.Context,
query := mds.bundb.NewSelect(). query := mds.bundb.NewSelect().
Table("users"). Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("g.name as role"). ColumnExpr("users.role as role").
ColumnExpr("o.display_name as organization"). ColumnExpr("o.name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Join("JOIN organizations o ON o.id = users.org_id"). Join("JOIN organizations o ON o.id = users.org_id").
Where("users.org_id = ?", orgId) Where("users.org_id = ?", orgId)
@ -317,19 +311,16 @@ func (mds *ModelDaoSqlite) GetUsersByOrg(ctx context.Context,
return users, nil return users, nil
} }
func (mds *ModelDaoSqlite) GetUsersByGroup(ctx context.Context, func (mds *ModelDaoSqlite) GetUsersByRole(ctx context.Context, role authtypes.Role) ([]types.GettableUser, *model.ApiError) {
groupId string) ([]types.GettableUser, *model.ApiError) {
users := []types.GettableUser{} users := []types.GettableUser{}
query := mds.bundb.NewSelect(). query := mds.bundb.NewSelect().
Table("users"). Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id"). Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("g.name as role"). ColumnExpr("users.role as role").
ColumnExpr("o.display_name as organization"). ColumnExpr("o.name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Join("JOIN organizations o ON o.id = users.org_id"). Join("JOIN organizations o ON o.id = users.org_id").
Where("users.group_id = ?", groupId) Where("users.role = ?", role)
err := query.Scan(ctx, &users) err := query.Scan(ctx, &users)
if err != nil { if err != nil {
@ -338,98 +329,7 @@ func (mds *ModelDaoSqlite) GetUsersByGroup(ctx context.Context,
return users, nil return users, nil
} }
func (mds *ModelDaoSqlite) CreateGroup(ctx context.Context, func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context, req *types.ResetPasswordRequest) *model.ApiError {
group *types.Group) (*types.Group, *model.ApiError) {
group.ID = uuid.NewString()
if _, err := mds.bundb.NewInsert().
Model(group).
Exec(ctx); err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
return group, nil
}
func (mds *ModelDaoSqlite) DeleteGroup(ctx context.Context, id string) *model.ApiError {
_, err := mds.bundb.NewDelete().
Model(&types.Group{}).
Where("id = ?", id).
Exec(ctx)
if err != nil {
return &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
return nil
}
func (mds *ModelDaoSqlite) GetGroup(ctx context.Context,
id string) (*types.Group, *model.ApiError) {
groups := []types.Group{}
if err := mds.bundb.NewSelect().
Model(&groups).
Where("id = ?", id).
Scan(ctx); err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
if len(groups) > 1 {
return nil, &model.ApiError{
Typ: model.ErrorInternal,
Err: errors.New("Found multiple groups with same ID."),
}
}
if len(groups) == 0 {
return nil, nil
}
return &groups[0], nil
}
func (mds *ModelDaoSqlite) GetGroupByName(ctx context.Context,
name string) (*types.Group, *model.ApiError) {
groups := []types.Group{}
err := mds.bundb.NewSelect().
Model(&groups).
Where("name = ?", name).
Scan(ctx)
if err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
if len(groups) > 1 {
return nil, &model.ApiError{
Typ: model.ErrorInternal,
Err: errors.New("Found multiple groups with same name"),
}
}
if len(groups) == 0 {
return nil, nil
}
return &groups[0], nil
}
// TODO(nitya): should have org id
func (mds *ModelDaoSqlite) GetGroups(ctx context.Context) ([]types.Group, *model.ApiError) {
groups := []types.Group{}
if err := mds.bundb.NewSelect().
Model(&groups).
Scan(ctx); err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
return groups, nil
}
func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context,
req *types.ResetPasswordRequest) *model.ApiError {
if _, err := mds.bundb.NewInsert(). if _, err := mds.bundb.NewInsert().
Model(req). Model(req).
@ -439,8 +339,7 @@ func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context,
return nil return nil
} }
func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context, func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context, token string) *model.ApiError {
token string) *model.ApiError {
_, err := mds.bundb.NewDelete(). _, err := mds.bundb.NewDelete().
Model(&types.ResetPasswordRequest{}). Model(&types.ResetPasswordRequest{}).
Where("token = ?", token). Where("token = ?", token).
@ -452,8 +351,7 @@ func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context,
return nil return nil
} }
func (mds *ModelDaoSqlite) GetResetPasswordEntry(ctx context.Context, func (mds *ModelDaoSqlite) GetResetPasswordEntry(ctx context.Context, token string) (*types.ResetPasswordRequest, *model.ApiError) {
token string) (*types.ResetPasswordRequest, *model.ApiError) {
entries := []types.ResetPasswordRequest{} entries := []types.ResetPasswordRequest{}
@ -491,14 +389,6 @@ func (mds *ModelDaoSqlite) PrecheckLogin(ctx context.Context, email, sourceUrl s
return resp, nil return resp, nil
} }
func (mds *ModelDaoSqlite) GetUserRole(ctx context.Context, groupId string) (string, error) {
role, err := mds.GetGroup(ctx, groupId)
if err != nil || role == nil {
return "", err
}
return role.Name, nil
}
func (mds *ModelDaoSqlite) GetUserCount(ctx context.Context) (int, error) { func (mds *ModelDaoSqlite) GetUserCount(ctx context.Context) (int, error) {
users, err := mds.GetUsers(ctx) users, err := mds.GetUsers(ctx)
if err != nil { if err != nil {

View File

@ -10,7 +10,6 @@ import (
"github.com/SigNoz/signoz/pkg/config/envprovider" "github.com/SigNoz/signoz/pkg/config/envprovider"
"github.com/SigNoz/signoz/pkg/config/fileprovider" "github.com/SigNoz/signoz/pkg/config/fileprovider"
"github.com/SigNoz/signoz/pkg/query-service/app" "github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/signoz" "github.com/SigNoz/signoz/pkg/signoz"
"github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/SigNoz/signoz/pkg/types/authtypes"
@ -139,10 +138,6 @@ func main() {
logger.Fatal("Could not start servers", zap.Error(err)) logger.Fatal("Could not start servers", zap.Error(err))
} }
if err := auth.InitAuthCache(context.Background()); err != nil {
logger.Fatal("Failed to initialize auth cache", zap.Error(err))
}
signoz.Start(context.Background()) signoz.Start(context.Background())
if err := signoz.Wait(context.Background()); err != nil { if err := signoz.Wait(context.Background()); err != nil {

View File

@ -332,9 +332,9 @@ func (m *Manager) Stop(ctx context.Context) {
// EditRuleDefinition writes the rule definition to the // EditRuleDefinition writes the rule definition to the
// datastore and also updates the rule executor // datastore and also updates the rule executor
func (m *Manager) EditRule(ctx context.Context, ruleStr string, idStr string) error { func (m *Manager) EditRule(ctx context.Context, ruleStr string, idStr string) error {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
return errors.New("claims not found in context") return err
} }
ruleUUID, err := valuer.NewUUID(idStr) ruleUUID, err := valuer.NewUUID(idStr)
@ -469,9 +469,9 @@ func (m *Manager) DeleteRule(ctx context.Context, idStr string) error {
return fmt.Errorf("delete rule received an rule id in invalid format, must be a valid uuid-v7") return fmt.Errorf("delete rule received an rule id in invalid format, must be a valid uuid-v7")
} }
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
return errors.New("claims not found in context") return err
} }
_, err = m.ruleStore.GetStoredRule(ctx, id) _, err = m.ruleStore.GetStoredRule(ctx, id)
@ -523,9 +523,9 @@ func (m *Manager) deleteTask(taskName string) {
// CreateRule stores rule def into db and also // CreateRule stores rule def into db and also
// starts an executor for the rule // starts an executor for the rule
func (m *Manager) CreateRule(ctx context.Context, ruleStr string) (*ruletypes.GettableRule, error) { func (m *Manager) CreateRule(ctx context.Context, ruleStr string) (*ruletypes.GettableRule, error) {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
return nil, errors.New("claims not found in context") return nil, err
} }
parsedRule, err := ruletypes.ParsePostableRule([]byte(ruleStr)) parsedRule, err := ruletypes.ParsePostableRule([]byte(ruleStr))
@ -804,9 +804,9 @@ func (m *Manager) ListActiveRules() ([]Rule, error) {
} }
func (m *Manager) ListRuleStates(ctx context.Context) (*ruletypes.GettableRules, error) { func (m *Manager) ListRuleStates(ctx context.Context) (*ruletypes.GettableRules, error) {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
return nil, errors.New("claims not found in context") return nil, err
} }
// fetch rules from DB // fetch rules from DB
storedRules, err := m.ruleStore.GetStoredRules(ctx, claims.OrgID) storedRules, err := m.ruleStore.GetStoredRules(ctx, claims.OrgID)
@ -918,9 +918,9 @@ func (m *Manager) syncRuleStateWithTask(ctx context.Context, orgID string, taskN
// - re-deploy or undeploy task as necessary // - re-deploy or undeploy task as necessary
// - update the patched rule in the DB // - update the patched rule in the DB
func (m *Manager) PatchRule(ctx context.Context, ruleStr string, ruleIdStr string) (*ruletypes.GettableRule, error) { func (m *Manager) PatchRule(ctx context.Context, ruleStr string, ruleIdStr string) (*ruletypes.GettableRule, error) {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
return nil, errors.New("claims not found in context") return nil, err
} }
ruleID, err := valuer.NewUUID(ruleIdStr) ruleID, err := valuer.NewUUID(ruleIdStr)
@ -1020,9 +1020,9 @@ func (m *Manager) TestNotification(ctx context.Context, ruleStr string) (int, *m
} }
func (m *Manager) GetAlertDetailsForMetricNames(ctx context.Context, metricNames []string) (map[string][]ruletypes.GettableRule, *model.ApiError) { func (m *Manager) GetAlertDetailsForMetricNames(ctx context.Context, metricNames []string) (map[string][]ruletypes.GettableRule, *model.ApiError) {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
return nil, &model.ApiError{Typ: model.ErrorExec, Err: errors.New("claims not found in context")} return nil, &model.ApiError{Typ: model.ErrorExec, Err: err}
} }
result := make(map[string][]ruletypes.GettableRule) result := make(map[string][]ruletypes.GettableRule)

View File

@ -206,7 +206,6 @@ type Telemetry struct {
alertsInfoCallback func(ctx context.Context) (*model.AlertsInfo, error) alertsInfoCallback func(ctx context.Context) (*model.AlertsInfo, error)
userCountCallback func(ctx context.Context) (int, error) userCountCallback func(ctx context.Context) (int, error)
userRoleCallback func(ctx context.Context, groupId string) (string, error)
getUsersCallback func(ctx context.Context) ([]types.GettableUser, *model.ApiError) getUsersCallback func(ctx context.Context) ([]types.GettableUser, *model.ApiError)
dashboardsInfoCallback func(ctx context.Context) (*model.DashboardsInfo, error) dashboardsInfoCallback func(ctx context.Context) (*model.DashboardsInfo, error)
savedViewsInfoCallback func(ctx context.Context) (*model.SavedViewsInfo, error) savedViewsInfoCallback func(ctx context.Context) (*model.SavedViewsInfo, error)
@ -220,10 +219,6 @@ func (a *Telemetry) SetUserCountCallback(callback func(ctx context.Context) (int
a.userCountCallback = callback a.userCountCallback = callback
} }
func (a *Telemetry) SetUserRoleCallback(callback func(ctx context.Context, groupId string) (string, error)) {
a.userRoleCallback = callback
}
func (a *Telemetry) SetGetUsersCallback(callback func(ctx context.Context) ([]types.GettableUser, *model.ApiError)) { func (a *Telemetry) SetGetUsersCallback(callback func(ctx context.Context) ([]types.GettableUser, *model.ApiError)) {
a.getUsersCallback = callback a.getUsersCallback = callback
} }
@ -555,21 +550,12 @@ func (a *Telemetry) IdentifyUser(user *types.User) {
if !a.isTelemetryEnabled() || a.isTelemetryAnonymous() { if !a.isTelemetryEnabled() || a.isTelemetryAnonymous() {
return return
} }
// extract user group from user.groupId
role, _ := a.userRoleCallback(context.Background(), user.GroupID)
if a.saasOperator != nil { if a.saasOperator != nil {
if role != "" { _ = a.saasOperator.Enqueue(analytics.Identify{
_ = a.saasOperator.Enqueue(analytics.Identify{ UserId: a.userEmail,
UserId: a.userEmail, Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email).Set("role", user.Role),
Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email).Set("role", role), })
})
} else {
_ = a.saasOperator.Enqueue(analytics.Identify{
UserId: a.userEmail,
Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email),
})
}
_ = a.saasOperator.Enqueue(analytics.Group{ _ = a.saasOperator.Enqueue(analytics.Group{
UserId: a.userEmail, UserId: a.userEmail,

View File

@ -10,9 +10,9 @@ import (
"testing" "testing"
"github.com/SigNoz/signoz/pkg/http/middleware" "github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/query-service/app" "github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants" "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/featureManager" "github.com/SigNoz/signoz/pkg/query-service/featureManager"
@ -310,7 +310,7 @@ func NewFilterSuggestionsTestBed(t *testing.T) *FilterSuggestionsTestBed {
router := app.NewRouter() router := app.NewRouter()
//add the jwt middleware //add the jwt middleware
router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
am := app.NewAuthMiddleware(auth.GetUserFromReqContext) am := middleware.NewAuthZ(instrumentationtest.New().Logger())
apiHandler.RegisterRoutes(router, am) apiHandler.RegisterRoutes(router, am)
apiHandler.RegisterQueryRangeV3Routes(router, am) apiHandler.RegisterQueryRangeV3Routes(router, am)

View File

@ -11,9 +11,9 @@ import (
"github.com/SigNoz/signoz/pkg/http/middleware" "github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest"
"github.com/SigNoz/signoz/pkg/query-service/app" "github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/app/cloudintegrations" "github.com/SigNoz/signoz/pkg/query-service/app/cloudintegrations"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/featureManager" "github.com/SigNoz/signoz/pkg/query-service/featureManager"
"github.com/SigNoz/signoz/pkg/query-service/utils" "github.com/SigNoz/signoz/pkg/query-service/utils"
@ -373,7 +373,7 @@ func NewCloudIntegrationsTestBed(t *testing.T, testDB sqlstore.SQLStore) *CloudI
router := app.NewRouter() router := app.NewRouter()
router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
am := app.NewAuthMiddleware(auth.GetUserFromReqContext) am := middleware.NewAuthZ(instrumentationtest.New().Logger())
apiHandler.RegisterRoutes(router, am) apiHandler.RegisterRoutes(router, am)
apiHandler.RegisterCloudIntegrationsRoutes(router, am) apiHandler.RegisterCloudIntegrationsRoutes(router, am)

View File

@ -9,11 +9,11 @@ import (
"time" "time"
"github.com/SigNoz/signoz/pkg/http/middleware" "github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization" "github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/query-service/app" "github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/app/cloudintegrations" "github.com/SigNoz/signoz/pkg/query-service/app/cloudintegrations"
"github.com/SigNoz/signoz/pkg/query-service/app/integrations" "github.com/SigNoz/signoz/pkg/query-service/app/integrations"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/featureManager" "github.com/SigNoz/signoz/pkg/query-service/featureManager"
"github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/model"
@ -580,7 +580,7 @@ func NewIntegrationsTestBed(t *testing.T, testDB sqlstore.SQLStore) *Integration
router := app.NewRouter() router := app.NewRouter()
router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap) router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
am := app.NewAuthMiddleware(auth.GetUserFromReqContext) am := middleware.NewAuthZ(instrumentationtest.New().Logger())
apiHandler.RegisterRoutes(router, am) apiHandler.RegisterRoutes(router, am)
apiHandler.RegisterIntegrationRoutes(router, am) apiHandler.RegisterIntegrationRoutes(router, am)

View File

@ -20,7 +20,6 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/app" "github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/app/clickhouseReader" "github.com/SigNoz/signoz/pkg/query-service/app/clickhouseReader"
"github.com/SigNoz/signoz/pkg/query-service/auth" "github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao" "github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/model" "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/sqlstore"
@ -158,13 +157,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
return nil, model.InternalError(err) return nil, model.InternalError(err)
} }
group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup)
if apiErr != nil {
return nil, model.InternalError(apiErr)
}
auth.InitAuthCache(ctx)
userId := uuid.NewString() userId := uuid.NewString()
return dao.DB().CreateUser( return dao.DB().CreateUser(
@ -175,7 +167,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
Email: userId[:8] + "test@test.com", Email: userId[:8] + "test@test.com",
Password: "test", Password: "test",
OrgID: organization.ID.StringValue(), OrgID: organization.ID.StringValue(),
GroupID: group.ID, Role: authtypes.RoleAdmin.String(),
}, },
true, true,
) )

View File

@ -58,6 +58,7 @@ func NewTestSqliteDB(t *testing.T) (sqlStore sqlstore.SQLStore, testDBFilePath s
sqlmigration.NewAddVirtualFieldsFactory(), sqlmigration.NewAddVirtualFieldsFactory(),
sqlmigration.NewUpdateIntegrationsFactory(sqlStore), sqlmigration.NewUpdateIntegrationsFactory(sqlStore),
sqlmigration.NewUpdateOrganizationsFactory(sqlStore), sqlmigration.NewUpdateOrganizationsFactory(sqlStore),
sqlmigration.NewDropGroupsFactory(sqlStore),
), ),
) )
if err != nil { if err != nil {

View File

@ -2,7 +2,6 @@ package sqlrulestore
import ( import (
"context" "context"
"errors"
"time" "time"
"github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/sqlstore"
@ -60,10 +59,9 @@ func (r *maintenance) GetPlannedMaintenanceByID(ctx context.Context, id valuer.U
} }
func (r *maintenance) CreatePlannedMaintenance(ctx context.Context, maintenance ruletypes.GettablePlannedMaintenance) (valuer.UUID, error) { func (r *maintenance) CreatePlannedMaintenance(ctx context.Context, maintenance ruletypes.GettablePlannedMaintenance) (valuer.UUID, error) {
claims, err := authtypes.ClaimsFromContext(ctx)
claims, ok := authtypes.ClaimsFromContext(ctx) if err != nil {
if !ok { return valuer.UUID{}, err
return valuer.UUID{}, errors.New("no claims found in context")
} }
storablePlannedMaintenance := ruletypes.StorablePlannedMaintenance{ storablePlannedMaintenance := ruletypes.StorablePlannedMaintenance{
@ -100,7 +98,7 @@ func (r *maintenance) CreatePlannedMaintenance(ctx context.Context, maintenance
}) })
} }
err := r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error { err = r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error {
_, err := r.sqlstore. _, err := r.sqlstore.
BunDBCtx(ctx). BunDBCtx(ctx).
NewInsert(). NewInsert().
@ -147,9 +145,9 @@ func (r *maintenance) DeletePlannedMaintenance(ctx context.Context, id valuer.UU
} }
func (r *maintenance) EditPlannedMaintenance(ctx context.Context, maintenance ruletypes.GettablePlannedMaintenance, id valuer.UUID) error { func (r *maintenance) EditPlannedMaintenance(ctx context.Context, maintenance ruletypes.GettablePlannedMaintenance, id valuer.UUID) error {
claims, ok := authtypes.ClaimsFromContext(ctx) claims, err := authtypes.ClaimsFromContext(ctx)
if !ok { if err != nil {
return errors.New("no claims found in context") return err
} }
storablePlannedMaintenance := ruletypes.StorablePlannedMaintenance{ storablePlannedMaintenance := ruletypes.StorablePlannedMaintenance{
@ -186,7 +184,7 @@ func (r *maintenance) EditPlannedMaintenance(ctx context.Context, maintenance ru
}) })
} }
err := r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error { err = r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error {
_, err := r.sqlstore. _, err := r.sqlstore.
BunDBCtx(ctx). BunDBCtx(ctx).
NewUpdate(). NewUpdate().

View File

@ -73,6 +73,7 @@ func NewSQLMigrationProviderFactories(sqlstore sqlstore.SQLStore) factory.NamedM
sqlmigration.NewAddVirtualFieldsFactory(), sqlmigration.NewAddVirtualFieldsFactory(),
sqlmigration.NewUpdateIntegrationsFactory(sqlstore), sqlmigration.NewUpdateIntegrationsFactory(sqlstore),
sqlmigration.NewUpdateOrganizationsFactory(sqlstore), sqlmigration.NewUpdateOrganizationsFactory(sqlstore),
sqlmigration.NewDropGroupsFactory(sqlstore),
) )
} }

View File

@ -19,12 +19,13 @@ import (
type SigNoz struct { type SigNoz struct {
*factory.Registry *factory.Registry
Cache cache.Cache Instrumentation instrumentation.Instrumentation
Web web.Web Cache cache.Cache
SQLStore sqlstore.SQLStore Web web.Web
TelemetryStore telemetrystore.TelemetryStore SQLStore sqlstore.SQLStore
Prometheus prometheus.Prometheus TelemetryStore telemetrystore.TelemetryStore
Alertmanager alertmanager.Alertmanager Prometheus prometheus.Prometheus
Alertmanager alertmanager.Alertmanager
} }
func New( func New(
@ -144,12 +145,13 @@ func New(
} }
return &SigNoz{ return &SigNoz{
Registry: registry, Registry: registry,
Cache: cache, Instrumentation: instrumentation,
Web: web, Cache: cache,
SQLStore: sqlstore, Web: web,
TelemetryStore: telemetrystore, SQLStore: sqlstore,
Prometheus: prometheus, TelemetryStore: telemetrystore,
Alertmanager: alertmanager, Prometheus: prometheus,
Alertmanager: alertmanager,
}, nil }, nil
} }

View File

@ -87,14 +87,14 @@ func (migration *updateOrganization) Up(ctx context.Context, db *bun.DB) error {
// since organizations, users has created_at as integer instead of timestamp // since organizations, users has created_at as integer instead of timestamp
for _, table := range []string{"organizations", "users", "invites"} { for _, table := range []string{"organizations", "users", "invites"} {
if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "created_at"); err != nil { if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "created_at"); err != nil {
return err return err
} }
} }
// migrate is_anonymous and has_opted_updates to boolean from int // migrate is_anonymous and has_opted_updates to boolean from int
for _, column := range []string{"is_anonymous", "has_opted_updates"} { for _, column := range []string{"is_anonymous", "has_opted_updates"} {
if err := migration.store.Dialect().MigrateIntToBoolean(ctx, tx, "organizations", column); err != nil { if err := migration.store.Dialect().IntToBoolean(ctx, tx, "organizations", column); err != nil {
return err return err
} }
} }

View File

@ -66,16 +66,16 @@ func (migration *updatePatAndOrgDomains) Up(ctx context.Context, db *bun.DB) err
} }
} }
if err := updateOrgId(ctx, tx, "org_domains"); err != nil { if err := updateOrgId(ctx, tx); err != nil {
return err return err
} }
// change created_at and updated_at from integer to timestamp // change created_at and updated_at from integer to timestamp
for _, table := range []string{"personal_access_tokens", "org_domains"} { for _, table := range []string{"personal_access_tokens", "org_domains"} {
if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "created_at"); err != nil { if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "created_at"); err != nil {
return err return err
} }
if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "updated_at"); err != nil { if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "updated_at"); err != nil {
return err return err
} }
} }
@ -96,7 +96,7 @@ func (migration *updatePatAndOrgDomains) Down(ctx context.Context, db *bun.DB) e
return nil return nil
} }
func updateOrgId(ctx context.Context, tx bun.Tx, table string) error { func updateOrgId(ctx context.Context, tx bun.Tx) error {
if _, err := tx.NewCreateTable(). if _, err := tx.NewCreateTable().
Model(&struct { Model(&struct {
bun.BaseModel `bun:"table:org_domains_new"` bun.BaseModel `bun:"table:org_domains_new"`

View File

@ -359,7 +359,20 @@ func (migration *updateIntegrations) CopyOldCloudIntegrationServicesToNewCloudIn
} }
func (migration *updateIntegrations) copyOldAwsIntegrationUser(tx bun.IDB, orgID string) error { func (migration *updateIntegrations) copyOldAwsIntegrationUser(tx bun.IDB, orgID string) error {
user := &types.User{} type oldUser struct {
bun.BaseModel `bun:"table:users"`
types.TimeAuditable
ID string `bun:"id,pk,type:text" json:"id"`
Name string `bun:"name,type:text,notnull" json:"name"`
Email string `bun:"email,type:text,notnull,unique" json:"email"`
Password string `bun:"password,type:text,notnull" json:"-"`
ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"`
GroupID string `bun:"group_id,type:text,notnull" json:"groupId"`
OrgID string `bun:"org_id,type:text,notnull" json:"orgId"`
}
user := &oldUser{}
err := tx.NewSelect().Model(user).Where("email = ?", "aws-integration@signoz.io").Scan(context.Background()) err := tx.NewSelect().Model(user).Where("email = ?", "aws-integration@signoz.io").Scan(context.Background())
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -374,7 +387,7 @@ func (migration *updateIntegrations) copyOldAwsIntegrationUser(tx bun.IDB, orgID
} }
// new user // new user
newUser := &types.User{ newUser := &oldUser{
ID: uuid.New().String(), ID: uuid.New().String(),
TimeAuditable: types.TimeAuditable{ TimeAuditable: types.TimeAuditable{
CreatedAt: time.Now(), CreatedAt: time.Now(),

View File

@ -0,0 +1,141 @@
package sqlmigration
import (
"context"
"github.com/SigNoz/signoz/pkg/factory"
"github.com/SigNoz/signoz/pkg/sqlstore"
"github.com/SigNoz/signoz/pkg/types"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate"
)
type dropGroups struct {
sqlstore sqlstore.SQLStore
}
func NewDropGroupsFactory(sqlstore sqlstore.SQLStore) factory.ProviderFactory[SQLMigration, Config] {
return factory.NewProviderFactory(factory.MustNewName("drop_groups"), func(ctx context.Context, providerSettings factory.ProviderSettings, config Config) (SQLMigration, error) {
return newDropGroups(ctx, providerSettings, config, sqlstore)
})
}
func newDropGroups(_ context.Context, _ factory.ProviderSettings, _ Config, sqlstore sqlstore.SQLStore) (SQLMigration, error) {
return &dropGroups{sqlstore: sqlstore}, nil
}
func (migration *dropGroups) Register(migrations *migrate.Migrations) error {
if err := migrations.Register(migration.Up, migration.Down); err != nil {
return err
}
return nil
}
func (migration *dropGroups) Up(ctx context.Context, db *bun.DB) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
type Group struct {
bun.BaseModel `bun:"table:groups"`
types.TimeAuditable
OrgID string `bun:"org_id,type:text"`
ID string `bun:"id,pk,type:text" json:"id"`
Name string `bun:"name,type:text,notnull,unique" json:"name"`
}
type existingUser struct {
bun.BaseModel `bun:"table:users"`
types.TimeAuditable
ID string `bun:"id,pk,type:text" json:"id"`
Name string `bun:"name,type:text,notnull" json:"name"`
Email string `bun:"email,type:text,notnull,unique" json:"email"`
Password string `bun:"password,type:text,notnull" json:"-"`
ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"`
GroupID string `bun:"group_id,type:text,notnull" json:"groupId"`
OrgID string `bun:"org_id,type:text,notnull" json:"orgId"`
}
var existingUsers []*existingUser
if err := tx.
NewSelect().
Model(&existingUsers).
Scan(ctx); err != nil {
return err
}
var groups []*Group
if err := tx.
NewSelect().
Model(&groups).
Scan(ctx); err != nil {
return err
}
groupIDToRoleMap := make(map[string]string)
for _, group := range groups {
groupIDToRoleMap[group.ID] = group.Name
}
roleToUserIDMap := make(map[string][]string)
for _, user := range existingUsers {
roleToUserIDMap[groupIDToRoleMap[user.GroupID]] = append(roleToUserIDMap[groupIDToRoleMap[user.GroupID]], user.ID)
}
if err := migration.sqlstore.Dialect().DropColumnWithForeignKeyConstraint(ctx, tx, new(struct {
bun.BaseModel `bun:"table:users"`
types.TimeAuditable
ID string `bun:"id,pk,type:text"`
Name string `bun:"name,type:text,notnull"`
Email string `bun:"email,type:text,notnull,unique"`
Password string `bun:"password,type:text,notnull"`
ProfilePictureURL string `bun:"profile_picture_url,type:text"`
OrgID string `bun:"org_id,type:text,notnull"`
}), "group_id"); err != nil {
return err
}
if err := migration.sqlstore.Dialect().AddColumn(ctx, tx, "users", "role", "TEXT"); err != nil {
return err
}
for role, userIDs := range roleToUserIDMap {
if _, err := tx.
NewUpdate().
Table("users").
Set("role = ?", role).
Where("id IN (?)", bun.In(userIDs)).
Exec(ctx); err != nil {
return err
}
}
if err := migration.sqlstore.Dialect().AddNotNullDefaultToColumn(ctx, tx, "users", "role", "TEXT", "'VIEWER'"); err != nil {
return err
}
if _, err := tx.
NewDropTable().
Table("groups").
IfExists().
Exec(ctx); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (migration *dropGroups) Down(ctx context.Context, db *bun.DB) error {
return nil
}

View File

@ -6,7 +6,6 @@ import (
"github.com/SigNoz/signoz/pkg/factory" "github.com/SigNoz/signoz/pkg/factory"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/migrate" "github.com/uptrace/bun/migrate"
) )
@ -66,29 +65,3 @@ func MustNew(
} }
return migrations return migrations
} }
func GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) {
var columnType string
var err error
if bun.Dialect().Name() == dialect.SQLite {
err = bun.NewSelect().
ColumnExpr("type").
TableExpr("pragma_table_info(?)", table).
Where("name = ?", column).
Scan(ctx, &columnType)
} else {
err = bun.NewSelect().
ColumnExpr("data_type").
TableExpr("information_schema.columns").
Where("table_name = ?", table).
Where("column_name = ?", column).
Scan(ctx, &columnType)
}
if err != nil {
return "", err
}
return columnType, nil
}

View File

@ -5,33 +5,53 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"slices" "slices"
"strings"
"github.com/SigNoz/signoz/pkg/errors" "github.com/SigNoz/signoz/pkg/errors"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
var ( const (
Identity = "id" Identity string = "id"
Integer = "INTEGER" Integer string = "INTEGER"
Text = "TEXT" Text string = "TEXT"
) )
var ( const (
Org = "org" Org string = "org"
User = "user" User string = "user"
CloudIntegration = "cloud_integration" CloudIntegration string = "cloud_integration"
) )
var ( const (
OrgReference = `("org_id") REFERENCES "organizations" ("id")` OrgReference string = `("org_id") REFERENCES "organizations" ("id")`
UserReference = `("user_id") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE` UserReference string = `("user_id") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE`
CloudIntegrationReference = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE` CloudIntegrationReference string = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE`
) )
type dialect struct { const (
OrgField string = "org_id"
)
type dialect struct{}
func (dialect *dialect) GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) {
var columnType string
err := bun.
NewSelect().
ColumnExpr("type").
TableExpr("pragma_table_info(?)", table).
Where("name = ?", column).
Scan(ctx, &columnType)
if err != nil {
return "", err
}
return columnType, nil
} }
func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
columnType, err := dialect.GetColumnType(ctx, bun, table, column) columnType, err := dialect.GetColumnType(ctx, bun, table, column)
if err != nil { if err != nil {
return err return err
@ -73,7 +93,7 @@ func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB,
return nil return nil
} }
func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
columnExists, err := dialect.ColumnExists(ctx, bun, table, column) columnExists, err := dialect.ColumnExists(ctx, bun, table, column)
if err != nil { if err != nil {
return err return err
@ -118,22 +138,6 @@ func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, ta
return nil return nil
} }
func (dialect *dialect) GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) {
var columnType string
err := bun.
NewSelect().
ColumnExpr("type").
TableExpr("pragma_table_info(?)", table).
Where("name = ?", column).
Scan(ctx, &columnType)
if err != nil {
return "", err
}
return columnType, nil
}
func (dialect *dialect) ColumnExists(ctx context.Context, bun bun.IDB, table string, column string) (bool, error) { func (dialect *dialect) ColumnExists(ctx context.Context, bun bun.IDB, table string, column string) (bool, error) {
var count int var count int
err := bun.NewSelect(). err := bun.NewSelect().
@ -217,7 +221,6 @@ func (dialect *dialect) DropColumn(ctx context.Context, bun bun.IDB, table strin
} }
func (dialect *dialect) TableExists(ctx context.Context, bun bun.IDB, table interface{}) (bool, error) { func (dialect *dialect) TableExists(ctx context.Context, bun bun.IDB, table interface{}) (bool, error) {
count := 0 count := 0
err := bun. err := bun.
NewSelect(). NewSelect().
@ -423,3 +426,63 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel
return nil return nil
} }
func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bunIDB bun.IDB, model interface{}, column string) error {
existingTable := bunIDB.Dialect().Tables().Get(reflect.TypeOf(model))
columnExists, err := dialect.ColumnExists(ctx, bunIDB, existingTable.Name, column)
if err != nil {
return err
}
if !columnExists {
return nil
}
newTableName := existingTable.Name + "_tmp"
// Create the newTmpTable query
createTableQuery := bunIDB.NewCreateTable().Model(model).ModelTableExpr(newTableName)
var columnNames []string
for _, field := range existingTable.Fields {
if field.Name != column {
columnNames = append(columnNames, string(field.SQLName))
}
if field.Name == OrgField {
createTableQuery = createTableQuery.ForeignKey(OrgReference)
}
}
// Disable foreign keys temporarily
if _, err := bunIDB.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil {
return err
}
if _, err = createTableQuery.Exec(ctx); err != nil {
return err
}
// Copy data from old table to new table
if _, err := bunIDB.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", newTableName, strings.Join(columnNames, ", "), existingTable.Name)); err != nil {
return err
}
_, err = bunIDB.NewDropTable().Table(existingTable.Name).Exec(ctx)
if err != nil {
return err
}
_, err = bunIDB.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME TO %s", newTableName, existingTable.Name))
if err != nil {
return err
}
// Re-enable foreign keys
if _, err := bunIDB.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
return err
}
return nil
}

View File

@ -37,15 +37,42 @@ type SQLStoreHook interface {
} }
type SQLDialect interface { type SQLDialect interface {
MigrateIntToTimestamp(context.Context, bun.IDB, string, string) error // Returns the type of the column for the given table and column.
MigrateIntToBoolean(context.Context, bun.IDB, string, string) error
AddNotNullDefaultToColumn(context.Context, bun.IDB, string, string, string, string) error
GetColumnType(context.Context, bun.IDB, string, string) (string, error) GetColumnType(context.Context, bun.IDB, string, string) (string, error)
// Migrates an integer column to a timestamp column for the given table and column.
IntToTimestamp(context.Context, bun.IDB, string, string) error
// Migrates an integer column to a boolean column for the given table and column.
IntToBoolean(context.Context, bun.IDB, string, string) error
// Adds a not null default to the given column for the given table, column, columnType and defaultValue.
AddNotNullDefaultToColumn(context.Context, bun.IDB, string, string, string, string) error
// Checks if a column exists in a table for the given table and column.
ColumnExists(context.Context, bun.IDB, string, string) (bool, error) ColumnExists(context.Context, bun.IDB, string, string) (bool, error)
// Adds a column to a table for the given table, column and columnType.
AddColumn(context.Context, bun.IDB, string, string, string) error AddColumn(context.Context, bun.IDB, string, string, string) error
RenameColumn(context.Context, bun.IDB, string, string, string) (bool, error)
// Drops a column from a table for the given table and column.
DropColumn(context.Context, bun.IDB, string, string) error DropColumn(context.Context, bun.IDB, string, string) error
// Renames a column in a table for the given table, old column name and new column name.
RenameColumn(context.Context, bun.IDB, string, string, string) (bool, error)
// Renames a table and modifies the given model for the given table, old model, new model, references and callback. The old model
// and new model must inherit bun.BaseModel.
RenameTableAndModifyModel(context.Context, bun.IDB, interface{}, interface{}, []string, func(context.Context) error) error RenameTableAndModifyModel(context.Context, bun.IDB, interface{}, interface{}, []string, func(context.Context) error) error
// Updates the primary key for the given table, old model, new model, reference and callback. The old model and new model
// must inherit bun.BaseModel.
UpdatePrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error UpdatePrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error
// Adds a primary key to the given table, old model, new model, reference and callback. The old model and new model
// must inherit bun.BaseModel.
AddPrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error AddPrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error
// Drops the column and the associated foreign key constraint for the given table and column.
DropColumnWithForeignKeyConstraint(context.Context, bun.IDB, interface{}, string) error
} }

View File

@ -9,11 +9,11 @@ import (
type dialect struct { type dialect struct {
} }
func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error { func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
return nil return nil
} }
func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error { func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
return nil return nil
} }
@ -56,3 +56,7 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel
func (dialect *dialect) IndexExists(ctx context.Context, bun bun.IDB, table string, index string) (bool, error) { func (dialect *dialect) IndexExists(ctx context.Context, bun bun.IDB, table string, index string) (bool, error) {
return false, nil return false, nil
} }
func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bun bun.IDB, model interface{}, column string) error {
return nil
}

View File

@ -0,0 +1,83 @@
package authtypes
import (
"log/slog"
"slices"
"github.com/SigNoz/signoz/pkg/errors"
"github.com/golang-jwt/jwt/v5"
)
var _ jwt.ClaimsValidator = (*Claims)(nil)
type Claims struct {
jwt.RegisteredClaims
UserID string `json:"id"`
Email string `json:"email"`
Role Role `json:"role"`
OrgID string `json:"orgId"`
}
func (c *Claims) Validate() error {
if c.UserID == "" {
return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "id is required")
}
// The problem is that when the "role" field is missing entirely from the JSON (as opposed to being present but empty), the UnmarshalJSON method for Role isn't called at all.
// The JSON decoder just sets the Role field to its zero value ("").
if c.Role == "" {
return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "role is required")
}
if c.OrgID == "" {
return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "orgId is required")
}
return nil
}
func (c *Claims) LogValue() slog.Value {
return slog.GroupValue(
slog.String("id", c.UserID),
slog.String("email", c.Email),
slog.String("role", c.Role.String()),
slog.String("orgId", c.OrgID),
slog.Time("exp", c.ExpiresAt.Time),
)
}
func (c *Claims) IsViewer() error {
if slices.Contains([]Role{RoleViewer, RoleEditor, RoleAdmin}, c.Role) {
return nil
}
return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only viewers/editors/admins can access this resource")
}
func (c *Claims) IsEditor() error {
if slices.Contains([]Role{RoleEditor, RoleAdmin}, c.Role) {
return nil
}
return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only editors/admins can access this resource")
}
func (c *Claims) IsAdmin() error {
if c.Role == RoleAdmin {
return nil
}
return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only admins can access this resource")
}
func (c *Claims) IsSelfAccess(id string) error {
if c.UserID == id {
return nil
}
if c.Role == RoleAdmin {
return nil
}
return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only the user/admin can access their own resource")
}

View File

@ -2,24 +2,15 @@ package authtypes
import ( import (
"context" "context"
"errors"
"fmt"
"strings" "strings"
"time" "time"
"github.com/SigNoz/signoz/pkg/errors"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
) )
type jwtClaimsKey struct{} type jwtClaimsKey struct{}
type Claims struct {
jwt.RegisteredClaims
UserID string `json:"id"`
GroupID string `json:"gid"`
Email string `json:"email"`
OrgID string `json:"orgId"`
}
type JWT struct { type JWT struct {
JwtSecret string JwtSecret string
JwtExpiry time.Duration JwtExpiry time.Duration
@ -34,16 +25,6 @@ func NewJWT(jwtSecret string, jwtExpiry time.Duration, jwtRefresh time.Duration)
} }
} }
func parseBearerAuth(auth string) (string, bool) {
const prefix = "Bearer "
// Case insensitive prefix match
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return "", false
}
return auth[len(prefix):], true
}
func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context.Context, error) { func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context.Context, error) {
var value string var value string
for _, v := range values { for _, v := range values {
@ -54,7 +35,7 @@ func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context
} }
if value == "" { if value == "" {
return ctx, errors.New("missing Authorization header") return ctx, errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "missing authorization header")
} }
// parse from // parse from
@ -73,24 +54,18 @@ func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context
} }
func (j *JWT) Claims(jwtStr string) (Claims, error) { func (j *JWT) Claims(jwtStr string) (Claims, error) {
token, err := jwt.ParseWithClaims(jwtStr, &Claims{}, func(token *jwt.Token) (interface{}, error) { claims := Claims{}
_, err := jwt.ParseWithClaims(jwtStr, &claims, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unknown signing algo: %v", token.Header["alg"]) return nil, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unrecognized signing algorithm: %s", token.Method.Alg())
} }
return []byte(j.JwtSecret), nil return []byte(j.JwtSecret), nil
}) })
if err != nil { if err != nil {
return Claims{}, fmt.Errorf("failed to parse jwt token: %w", err) return Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to parse jwt token")
} }
// Type assertion to retrieve claims from the token return claims, nil
userClaims, ok := token.Claims.(*Claims)
if !ok {
return Claims{}, errors.New("failed to retrieve claims from token")
}
return *userClaims, nil
} }
// NewContextWithClaims attaches individual claims to the context. // NewContextWithClaims attaches individual claims to the context.
@ -106,36 +81,62 @@ func (j *JWT) signToken(claims Claims) (string, error) {
} }
// AccessToken creates an access token with the provided claims // AccessToken creates an access token with the provided claims
func (j *JWT) AccessToken(orgId, userId, groupId, email string) (string, error) { func (j *JWT) AccessToken(orgId, userId, email string, role Role) (string, Claims, error) {
claims := Claims{ claims := Claims{
UserID: userId, UserID: userId,
GroupID: groupId, Role: role,
Email: email, Email: email,
OrgID: orgId, OrgID: orgId,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtExpiry)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtExpiry)),
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),
}, },
} }
return j.signToken(claims)
token, err := j.signToken(claims)
if err != nil {
return "", Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to sign token")
}
return token, claims, nil
} }
// RefreshToken creates a refresh token with the provided claims // RefreshToken creates a refresh token with the provided claims
func (j *JWT) RefreshToken(orgId, userId, groupId, email string) (string, error) { func (j *JWT) RefreshToken(orgId, userId, email string, role Role) (string, Claims, error) {
claims := Claims{ claims := Claims{
UserID: userId, UserID: userId,
GroupID: groupId, Role: role,
Email: email, Email: email,
OrgID: orgId, OrgID: orgId,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtRefresh)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtRefresh)),
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),
}, },
} }
return j.signToken(claims)
token, err := j.signToken(claims)
if err != nil {
return "", Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to sign token")
}
return token, claims, nil
} }
func ClaimsFromContext(ctx context.Context) (Claims, bool) { func ClaimsFromContext(ctx context.Context) (Claims, error) {
claims, ok := ctx.Value(jwtClaimsKey{}).(Claims) claims, ok := ctx.Value(jwtClaimsKey{}).(Claims)
return claims, ok if !ok {
return Claims{}, errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")
}
return claims, nil
}
func parseBearerAuth(auth string) (string, bool) {
const prefix = "Bearer "
// Case insensitive prefix match
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return "", false
}
return auth[len(prefix):], true
} }

View File

@ -4,35 +4,36 @@ import (
"testing" "testing"
"time" "time"
"github.com/SigNoz/signoz/pkg/errors"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestGetAccessJwt(t *testing.T) { func TestJwtAccessToken(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour) jwtService := NewJWT("secret", time.Minute, time.Hour)
token, err := jwtService.AccessToken("orgId", "userId", "groupId", "email@example.com") token, _, err := jwtService.AccessToken("orgId", "userId", "email@example.com", RoleAdmin)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, token) assert.NotEmpty(t, token)
} }
func TestGetRefreshJwt(t *testing.T) { func TestJwtRefreshToken(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour) jwtService := NewJWT("secret", time.Minute, time.Hour)
token, err := jwtService.RefreshToken("orgId", "userId", "groupId", "email@example.com") token, _, err := jwtService.RefreshToken("orgId", "userId", "email@example.com", RoleAdmin)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, token) assert.NotEmpty(t, token)
} }
func TestGetJwtClaims(t *testing.T) { func TestJwtClaims(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour) jwtService := NewJWT("secret", time.Minute, time.Hour)
// Create a valid token // Create a valid token
claims := Claims{ claims := Claims{
UserID: "userId", UserID: "userId",
GroupID: "groupId", Role: RoleAdmin,
Email: "email@example.com", Email: "email@example.com",
OrgID: "orgId", OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),
@ -45,29 +46,28 @@ func TestGetJwtClaims(t *testing.T) {
retrievedClaims, err := jwtService.Claims(tokenString) retrievedClaims, err := jwtService.Claims(tokenString)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, claims.UserID, retrievedClaims.UserID) assert.Equal(t, claims.UserID, retrievedClaims.UserID)
assert.Equal(t, claims.GroupID, retrievedClaims.GroupID) assert.Equal(t, claims.Role, retrievedClaims.Role)
assert.Equal(t, claims.Email, retrievedClaims.Email) assert.Equal(t, claims.Email, retrievedClaims.Email)
assert.Equal(t, claims.OrgID, retrievedClaims.OrgID) assert.Equal(t, claims.OrgID, retrievedClaims.OrgID)
} }
func TestGetJwtClaimsInvalidToken(t *testing.T) { func TestJwtClaimsInvalidToken(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour) jwtService := NewJWT("secret", time.Minute, time.Hour)
// Test retrieving claims from an invalid token
_, err := jwtService.Claims("invalid.token.string") _, err := jwtService.Claims("invalid.token.string")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "token is malformed") assert.Contains(t, err.Error(), "token is malformed")
} }
func TestGetJwtClaimsExpiredToken(t *testing.T) { func TestJwtClaimsExpiredToken(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour) jwtService := NewJWT("secret", time.Minute, time.Hour)
// Create an expired token // Create an expired token
claims := Claims{ claims := Claims{
UserID: "userId", UserID: "userId",
GroupID: "groupId", Role: RoleAdmin,
Email: "email@example.com", Email: "email@example.com",
OrgID: "orgId", OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),
@ -81,15 +81,15 @@ func TestGetJwtClaimsExpiredToken(t *testing.T) {
assert.Contains(t, err.Error(), "token is expired") assert.Contains(t, err.Error(), "token is expired")
} }
func TestGetJwtClaimsInvalidSignature(t *testing.T) { func TestJwtClaimsInvalidSignature(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour) jwtService := NewJWT("secret", time.Minute, time.Hour)
// Create a valid token // Create a valid token
claims := Claims{ claims := Claims{
UserID: "userId", UserID: "userId",
GroupID: "groupId", Role: RoleAdmin,
Email: "email@example.com", Email: "email@example.com",
OrgID: "orgId", OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
}, },
@ -106,6 +106,86 @@ func TestGetJwtClaimsInvalidSignature(t *testing.T) {
assert.Contains(t, err.Error(), "signature is invalid") assert.Contains(t, err.Error(), "signature is invalid")
} }
func TestJwtClaimsWithInvalidRole(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
claims := Claims{
UserID: "userId",
Role: "INVALID_ROLE",
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
}
validToken, err := jwtService.signToken(claims)
assert.NoError(t, err)
_, err = jwtService.Claims(validToken)
assert.Error(t, err)
assert.True(t, errors.Ast(err, errors.TypeUnauthenticated))
}
func TestJwtClaimsMissingUserID(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
claims := Claims{
UserID: "",
Role: RoleAdmin,
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
}
validToken, err := jwtService.signToken(claims)
assert.NoError(t, err)
_, err = jwtService.Claims(validToken)
assert.Error(t, err)
assert.True(t, errors.Ast(err, errors.TypeUnauthenticated))
}
func TestJwtClaimsMissingRole(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
claims := Claims{
UserID: "userId",
Role: "",
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
}
validToken, err := jwtService.signToken(claims)
assert.NoError(t, err)
_, err = jwtService.Claims(validToken)
assert.Error(t, err)
assert.True(t, errors.Ast(err, errors.TypeUnauthenticated))
}
func TestJwtClaimsMissingOrgID(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
claims := Claims{
UserID: "userId",
Role: RoleAdmin,
Email: "email@example.com",
OrgID: "",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
}
validToken, err := jwtService.signToken(claims)
assert.NoError(t, err)
_, err = jwtService.Claims(validToken)
assert.Error(t, err)
assert.True(t, errors.Ast(err, errors.TypeUnauthenticated))
}
func TestParseBearerAuth(t *testing.T) { func TestParseBearerAuth(t *testing.T) {
tests := []struct { tests := []struct {
auth string auth string

View File

@ -0,0 +1,52 @@
package authtypes
import (
"encoding/json"
"github.com/SigNoz/signoz/pkg/errors"
)
// Do not take inspiration from this. This is a hack to avoid using valuer.String and use upper case strings.
type Role string
const (
RoleAdmin Role = "ADMIN"
RoleEditor Role = "EDITOR"
RoleViewer Role = "VIEWER"
)
func NewRole(role string) (Role, error) {
switch role {
case "ADMIN":
return RoleAdmin, nil
case "EDITOR":
return RoleEditor, nil
case "VIEWER":
return RoleViewer, nil
}
return "", errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid role: %s", role)
}
func (r Role) String() string {
return string(r)
}
func (r *Role) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
role, err := NewRole(s)
if err != nil {
return err
}
*r = role
return nil
}
func (r Role) MarshalJSON() ([]byte, error) {
return json.Marshal(r.String())
}

View File

@ -16,18 +16,8 @@ type Invite struct {
Role string `bun:"role,type:text,notnull" json:"role"` Role string `bun:"role,type:text,notnull" json:"role"`
} }
type Group struct {
bun.BaseModel `bun:"table:groups"`
TimeAuditable
OrgID string `bun:"org_id,type:text"`
ID string `bun:"id,pk,type:text" json:"id"`
Name string `bun:"name,type:text,notnull,unique" json:"name"`
}
type GettableUser struct { type GettableUser struct {
User User
Role string `json:"role"`
Organization string `json:"organization"` Organization string `json:"organization"`
} }
@ -40,7 +30,7 @@ type User struct {
Email string `bun:"email,type:text,notnull,unique" json:"email"` Email string `bun:"email,type:text,notnull,unique" json:"email"`
Password string `bun:"password,type:text,notnull" json:"-"` Password string `bun:"password,type:text,notnull" json:"-"`
ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"` ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"`
GroupID string `bun:"group_id,type:text,notnull" json:"groupId"` Role string `bun:"role,type:text,notnull" json:"role"`
OrgID string `bun:"org_id,type:text,notnull" json:"orgId"` OrgID string `bun:"org_id,type:text,notnull" json:"orgId"`
} }

View File

@ -16,6 +16,12 @@ pytest_plugins = [
def pytest_addoption(parser: pytest.Parser): def pytest_addoption(parser: pytest.Parser):
parser.addoption(
"--dev",
action="store",
default=False,
help="Run in dev mode. In this mode, the containers are not torn down after test run and are reused in subsequent test runs.",
)
parser.addoption( parser.addoption(
"--sqlstore-provider", "--sqlstore-provider",
action="store", action="store",

View File

@ -0,0 +1,3 @@
from testcontainers.core.config import testcontainers_config as config
config.ryuk_disabled = True

View File

@ -1,3 +1,4 @@
import dataclasses
import os import os
from typing import Any, Generator from typing import Any, Generator
@ -15,10 +16,46 @@ def clickhouse(
network: Network, network: Network,
zookeeper: types.TestContainerDocker, zookeeper: types.TestContainerDocker,
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> types.TestContainerClickhouse: ) -> types.TestContainerClickhouse:
""" """
Package-scoped fixture for Clickhouse TestContainer. Package-scoped fixture for Clickhouse TestContainer.
""" """
dev = request.config.getoption("--dev")
if dev:
container = pytestconfig.cache.get("clickhouse.container", None)
env = pytestconfig.cache.get("clickhouse.env", None)
if container and env:
assert isinstance(container, dict)
assert isinstance(env, dict)
test_container = types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
container["host_config"]["scheme"],
container["host_config"]["address"],
container["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
container["container_config"]["scheme"],
container["container_config"]["address"],
container["container_config"]["port"],
),
)
connection = clickhouse_driver.connect(
user=env["SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_USERNAME"],
password=env["SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_PASSWORD"],
host=test_container.host_config.address,
port=test_container.host_config.port,
)
return types.TestContainerClickhouse(
container=test_container,
conn=connection,
env=env,
)
version = request.config.getoption("--clickhouse-version") version = request.config.getoption("--clickhouse-version")
container = ClickHouseContainer( container = ClickHouseContainer(
@ -91,21 +128,37 @@ def clickhouse(
) )
def stop(): def stop():
if dev:
return
connection.close() connection.close()
container.stop(delete_volume=True) container.stop(delete_volume=True)
request.addfinalizer(stop) request.addfinalizer(stop)
return types.TestContainerClickhouse( cached_clickhouse = types.TestContainerClickhouse(
container=container, container=types.TestContainerDocker(
host_config=types.TestContainerUrlConfig( host_config=types.TestContainerUrlConfig(
"tcp", container.get_container_host_ip(), container.get_exposed_port(9000) "tcp",
), container.get_container_host_ip(),
container_config=types.TestContainerUrlConfig( container.get_exposed_port(9000),
"tcp", container.get_wrapped_container().name, 9000 ),
container_config=types.TestContainerUrlConfig(
"tcp", container.get_wrapped_container().name, 9000
),
), ),
conn=connection, conn=connection,
env={ env={
"SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_DSN": f"tcp://{container.username}:{container.password}@{container.get_wrapped_container().name}:{9000}" # pylint: disable=line-too-long "SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_DSN": f"tcp://{container.username}:{container.password}@{container.get_wrapped_container().name}:{9000}", # pylint: disable=line-too-long
"SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_USERNAME": container.username,
"SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_PASSWORD": container.password,
}, },
) )
if dev:
pytestconfig.cache.set(
"clickhouse.container", dataclasses.asdict(cached_clickhouse.container)
)
pytestconfig.cache.set("clickhouse.env", cached_clickhouse.env)
return cached_clickhouse

View File

@ -1,3 +1,4 @@
import dataclasses
from typing import List from typing import List
import pytest import pytest
@ -14,23 +15,44 @@ from fixtures import types
@pytest.fixture(name="zeus", scope="package") @pytest.fixture(name="zeus", scope="package")
def zeus( def zeus(
network: Network, request: pytest.FixtureRequest network: Network,
) -> types.TestContainerWiremock: request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> types.TestContainerDocker:
""" """
Package-scoped fixture for running zeus Package-scoped fixture for running zeus
""" """
dev = request.config.getoption("--dev")
if dev:
cached_zeus = pytestconfig.cache.get("zeus", None)
if cached_zeus:
return types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
cached_zeus["host_config"]["scheme"],
cached_zeus["host_config"]["address"],
cached_zeus["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
cached_zeus["container_config"]["scheme"],
cached_zeus["container_config"]["address"],
cached_zeus["container_config"]["port"],
),
)
container = WireMockContainer(image="wiremock/wiremock:2.35.1-1", secure=False) container = WireMockContainer(image="wiremock/wiremock:2.35.1-1", secure=False)
container.with_network(network) container.with_network(network)
container.start() container.start()
def stop(): def stop():
if dev:
return
container.stop(delete_volume=True) container.stop(delete_volume=True)
request.addfinalizer(stop) request.addfinalizer(stop)
return types.TestContainerWiremock( cached_zeus = types.TestContainerDocker(
container=container,
host_config=types.TestContainerUrlConfig( host_config=types.TestContainerUrlConfig(
"http", container.get_container_host_ip(), container.get_exposed_port(8080) "http", container.get_container_host_ip(), container.get_exposed_port(8080)
), ),
@ -39,11 +61,16 @@ def zeus(
), ),
) )
if dev:
pytestconfig.cache.set("zeus", dataclasses.asdict(cached_zeus))
return cached_zeus
@pytest.fixture(name="make_http_mocks", scope="function") @pytest.fixture(name="make_http_mocks", scope="function")
def make_http_mocks(): def make_http_mocks():
def _make_http_mocks(container: WireMockContainer, mappings: List[Mapping]): def _make_http_mocks(container: types.TestContainerDocker, mappings: List[Mapping]):
Config.base_url = container.get_url("__admin") Config.base_url = container.host_config.get("/__admin")
for mapping in mappings: for mapping in mappings:
Mappings.create_mapping(mapping=mapping) Mappings.create_mapping(mapping=mapping)

View File

@ -0,0 +1,10 @@
import logging
def setup_logger(name: str) -> logging.Logger:
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)
return logger

View File

@ -10,10 +10,17 @@ def migrator(
network: Network, network: Network,
clickhouse: types.TestContainerClickhouse, clickhouse: types.TestContainerClickhouse,
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> None: ) -> None:
""" """
Package-scoped fixture for running schema migrations. Package-scoped fixture for running schema migrations.
""" """
dev = request.config.getoption("--dev")
if dev:
cached_migrator = pytestconfig.cache.get("migrator", None)
if cached_migrator is not None and cached_migrator is True:
return None
version = request.config.getoption("--schema-migrator-version") version = request.config.getoption("--schema-migrator-version")
client = docker.from_env() client = docker.from_env()
@ -53,3 +60,6 @@ def migrator(
raise RuntimeError("failed to run migrations on clickhouse") raise RuntimeError("failed to run migrations on clickhouse")
container.remove() container.remove()
if dev:
pytestconfig.cache.set("migrator", True)

View File

@ -1,18 +1,45 @@
import pytest import pytest
from testcontainers.core.container import Network from testcontainers.core.container import Network
from fixtures.logger import setup_logger
logger = setup_logger(__name__)
@pytest.fixture(name="network", scope="package") @pytest.fixture(name="network", scope="package")
def network(request: pytest.FixtureRequest) -> Network: def network(request: pytest.FixtureRequest, pytestconfig: pytest.Config) -> Network:
""" """
Package-Scoped fixture for creating a network Package-Scoped fixture for creating a network
""" """
nw = Network() nw = Network()
dev = request.config.getoption("--dev")
if dev:
cached_network = pytestconfig.cache.get("network", None)
if cached_network:
logger.info("Using cached Network(%s)", cached_network)
nw.id = cached_network["id"]
nw.name = cached_network["name"]
return nw
nw.create() nw.create()
def stop(): def stop():
nw.remove() dev = request.config.getoption("--dev")
if dev:
logger.info(
"Skipping removal of Network(%s)", {"name": nw.name, "id": nw.id}
)
else:
logger.info("Removing Network(%s)", {"name": nw.name, "id": nw.id})
nw.remove()
request.addfinalizer(stop) request.addfinalizer(stop)
return nw cached_network = nw
if dev:
pytestconfig.cache.set(
"network", {"name": cached_network.name, "id": cached_network.id}
)
return cached_network

View File

@ -1,3 +1,5 @@
import dataclasses
import psycopg2 import psycopg2
import pytest import pytest
from testcontainers.core.container import Network from testcontainers.core.container import Network
@ -8,11 +10,45 @@ from fixtures import types
@pytest.fixture(name="postgres", scope="package") @pytest.fixture(name="postgres", scope="package")
def postgres( def postgres(
network: Network, request: pytest.FixtureRequest network: Network, request: pytest.FixtureRequest, pytestconfig: pytest.Config
) -> types.TestContainerSQL: ) -> types.TestContainerSQL:
""" """
Package-scoped fixture for PostgreSQL TestContainer. Package-scoped fixture for PostgreSQL TestContainer.
""" """
dev = request.config.getoption("--dev")
if dev:
container = pytestconfig.cache.get("postgres.container", None)
env = pytestconfig.cache.get("postgres.env", None)
if container and env:
assert isinstance(container, dict)
assert isinstance(env, dict)
test_container = types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
container["host_config"]["scheme"],
container["host_config"]["address"],
container["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
container["container_config"]["scheme"],
container["container_config"]["address"],
container["container_config"]["port"],
),
)
return types.TestContainerSQL(
container=test_container,
conn=psycopg2.connect(
dbname=env["SIGNOZ_SQLSTORE_POSTGRES_DBNAME"],
user=env["SIGNOZ_SQLSTORE_POSTGRES_USER"],
password=env["SIGNOZ_SQLSTORE_POSTGRES_PASSWORD"],
host=test_container.host_config.address,
port=test_container.host_config.port,
),
env=env,
)
version = request.config.getoption("--postgres-version") version = request.config.getoption("--postgres-version")
container = PostgresContainer( container = PostgresContainer(
@ -35,24 +71,39 @@ def postgres(
) )
def stop(): def stop():
if dev:
return
connection.close() connection.close()
container.stop(delete_volume=True) container.stop(delete_volume=True)
request.addfinalizer(stop) request.addfinalizer(stop)
return types.TestContainerSQL( cached_postgres = types.TestContainerSQL(
container=container, container=types.TestContainerDocker(
host_config=types.TestContainerUrlConfig( host_config=types.TestContainerUrlConfig(
"postgresql", "postgresql",
container.get_container_host_ip(), container.get_container_host_ip(),
container.get_exposed_port(5432), container.get_exposed_port(5432),
), ),
container_config=types.TestContainerUrlConfig( container_config=types.TestContainerUrlConfig(
"postgresql", container.get_wrapped_container().name, 5432 "postgresql", container.get_wrapped_container().name, 5432
),
), ),
conn=connection, conn=connection,
env={ env={
"SIGNOZ_SQLSTORE_PROVIDER": "postgres", "SIGNOZ_SQLSTORE_PROVIDER": "postgres",
"SIGNOZ_SQLSTORE_POSTGRES_DSN": f"postgresql://{container.username}:{container.password}@{container.get_wrapped_container().name}:{5432}/{container.dbname}", # pylint: disable=line-too-long "SIGNOZ_SQLSTORE_POSTGRES_DSN": f"postgresql://{container.username}:{container.password}@{container.get_wrapped_container().name}:{5432}/{container.dbname}", # pylint: disable=line-too-long
"SIGNOZ_SQLSTORE_POSTGRES_DBNAME": container.dbname,
"SIGNOZ_SQLSTORE_POSTGRES_USER": container.username,
"SIGNOZ_SQLSTORE_POSTGRES_PASSWORD": container.password,
}, },
) )
if dev:
pytestconfig.cache.set(
"postgres.container", dataclasses.asdict(cached_postgres.container)
)
pytestconfig.cache.set("postgres.env", cached_postgres.env)
return cached_postgres

View File

@ -1,3 +1,4 @@
import dataclasses
import platform import platform
import time import time
from http import HTTPStatus from http import HTTPStatus
@ -8,20 +9,44 @@ from testcontainers.core.container import DockerContainer, Network
from testcontainers.core.image import DockerImage from testcontainers.core.image import DockerImage
from fixtures import types from fixtures import types
from fixtures.logger import setup_logger
logger = setup_logger(__name__)
@pytest.fixture(name="signoz", scope="package") @pytest.fixture(name="signoz", scope="package")
def signoz( def signoz(
network: Network, network: Network,
zeus: types.TestContainerWiremock, zeus: types.TestContainerDocker,
sqlstore: types.TestContainerSQL, sqlstore: types.TestContainerSQL,
clickhouse: types.TestContainerClickhouse, clickhouse: types.TestContainerClickhouse,
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> types.SigNoz: ) -> types.SigNoz:
""" """
Package-scoped fixture for setting up SigNoz. Package-scoped fixture for setting up SigNoz.
""" """
dev = request.config.getoption("--dev")
if dev:
cached_signoz = pytestconfig.cache.get("signoz.container", None)
if cached_signoz:
self = types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
cached_signoz["host_config"]["scheme"],
cached_signoz["host_config"]["address"],
cached_signoz["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
cached_signoz["container_config"]["scheme"],
cached_signoz["container_config"]["address"],
cached_signoz["container_config"]["port"],
),
)
return types.SigNoz(
self=self, sqlstore=sqlstore, telemetrystore=clickhouse, zeus=zeus
)
# Run the migrations for clickhouse # Run the migrations for clickhouse
request.getfixturevalue("migrator") request.getfixturevalue("migrator")
@ -70,7 +95,7 @@ def signoz(
container.start() container.start()
def ready(container: DockerContainer) -> None: def ready(container: DockerContainer) -> None:
for attempt in range(30): for attempt in range(5):
try: try:
response = requests.get( response = requests.get(
f"http://{container.get_container_host_ip()}:{container.get_exposed_port(8080)}/api/v1/health", # pylint: disable=line-too-long f"http://{container.get_container_host_ip()}:{container.get_exposed_port(8080)}/api/v1/health", # pylint: disable=line-too-long
@ -78,22 +103,31 @@ def signoz(
) )
return response.status_code == HTTPStatus.OK return response.status_code == HTTPStatus.OK
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
print(f"attempt {attempt} at health check failed") logger.info(
"Attempt %s at readiness check for SigNoz container %s failed, going to retry ...", # pylint: disable=line-too-long
attempt + 1,
container,
)
time.sleep(2) time.sleep(2)
raise TimeoutError("timeout exceeded while waiting") raise TimeoutError("timeout exceeded while waiting")
ready(container=container) try:
ready(container=container)
except Exception as e: # pylint: disable=broad-exception-caught
raise e
def stop(): def stop():
logs = container.get_wrapped_container().logs(tail=100) if dev:
print(logs.decode(encoding="utf-8")) logger.info("Skipping removal of SigNoz container %s ...", container)
container.stop(delete_volume=True) return
else:
logger.info("Removing SigNoz container %s ...", container)
container.stop(delete_volume=True)
request.addfinalizer(stop) request.addfinalizer(stop)
return types.SigNoz( cached_signoz = types.SigNoz(
self=types.TestContainerDocker( self=types.TestContainerDocker(
container=container,
host_config=types.TestContainerUrlConfig( host_config=types.TestContainerUrlConfig(
"http", "http",
container.get_container_host_ip(), container.get_container_host_ip(),
@ -109,3 +143,10 @@ def signoz(
telemetrystore=clickhouse, telemetrystore=clickhouse,
zeus=zeus, zeus=zeus,
) )
if dev:
pytestconfig.cache.set(
"signoz.container", dataclasses.asdict(cached_signoz.self)
)
return cached_signoz

View File

@ -11,27 +11,59 @@ ConnectionTuple = namedtuple("ConnectionTuple", "connection config")
@pytest.fixture(name="sqlite", scope="package") @pytest.fixture(name="sqlite", scope="package")
def sqlite( def sqlite(
tmpfs: Generator[types.LegacyPath, Any, None], request: pytest.FixtureRequest tmpfs: Generator[types.LegacyPath, Any, None],
request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> types.TestContainerSQL: ) -> types.TestContainerSQL:
""" """
Package-scoped fixture for SQLite. Package-scoped fixture for SQLite.
""" """
dev = request.config.getoption("--dev")
if dev:
container = pytestconfig.cache.get("sqlite.container", None)
env = pytestconfig.cache.get("sqlite.env", None)
if container and env:
assert isinstance(container, dict)
assert isinstance(env, dict)
return types.TestContainerSQL(
container=types.TestContainerDocker(
host_config=None,
container_config=None,
),
conn=sqlite3.connect(
env["SIGNOZ_SQLSTORE_SQLITE_PATH"], check_same_thread=False
),
env=env,
)
tmpdir = tmpfs("sqlite") tmpdir = tmpfs("sqlite")
path = tmpdir / "signoz.db" path = tmpdir / "signoz.db"
connection = sqlite3.connect(path, check_same_thread=False) connection = sqlite3.connect(path, check_same_thread=False)
def stop(): def stop():
if dev:
return
connection.close() connection.close()
request.addfinalizer(stop) request.addfinalizer(stop)
return types.TestContainerSQL( cached_sqlite = types.TestContainerSQL(
None, container=types.TestContainerDocker(
host_config=None, host_config=None,
container_config=None, container_config=None,
),
conn=connection, conn=connection,
env={ env={
"SIGNOZ_SQLSTORE_PROVIDER": "sqlite", "SIGNOZ_SQLSTORE_PROVIDER": "sqlite",
"SIGNOZ_SQLSTORE_SQLITE_PATH": str(path), "SIGNOZ_SQLSTORE_SQLITE_PATH": str(path),
}, },
) )
if dev:
pytestconfig.cache.set("sqlite.env", cached_sqlite.env)
return cached_sqlite

View File

@ -4,8 +4,6 @@ from urllib.parse import urljoin
import py import py
from clickhouse_driver.dbapi import Connection from clickhouse_driver.dbapi import Connection
from testcontainers.core.container import DockerContainer
from wiremock.testing.testcontainer import WireMockContainer
LegacyPath = py.path.local LegacyPath = py.path.local
@ -27,29 +25,22 @@ class TestContainerUrlConfig:
@dataclass @dataclass
class TestContainerDocker: class TestContainerDocker:
__test__ = False __test__ = False
container: DockerContainer
host_config: TestContainerUrlConfig host_config: TestContainerUrlConfig
container_config: TestContainerUrlConfig container_config: TestContainerUrlConfig
@dataclass @dataclass
class TestContainerWiremock(TestContainerDocker): class TestContainerSQL:
__test__ = False __test__ = False
container: WireMockContainer container: TestContainerDocker
@dataclass
class TestContainerSQL(TestContainerDocker):
__test__ = False
container: DockerContainer
conn: any conn: any
env: Dict[str, str] env: Dict[str, str]
@dataclass @dataclass
class TestContainerClickhouse(TestContainerDocker): class TestContainerClickhouse:
__test__ = False __test__ = False
container: DockerContainer container: TestContainerDocker
conn: Connection conn: Connection
env: Dict[str, str] env: Dict[str, str]
@ -60,4 +51,4 @@ class SigNoz:
self: TestContainerDocker self: TestContainerDocker
sqlstore: TestContainerSQL sqlstore: TestContainerSQL
telemetrystore: TestContainerClickhouse telemetrystore: TestContainerClickhouse
zeus: TestContainerWiremock zeus: TestContainerDocker

View File

@ -1,3 +1,5 @@
import dataclasses
import pytest import pytest
from testcontainers.core.container import DockerContainer, Network from testcontainers.core.container import DockerContainer, Network
@ -6,11 +8,29 @@ from fixtures import types
@pytest.fixture(name="zookeeper", scope="package") @pytest.fixture(name="zookeeper", scope="package")
def zookeeper( def zookeeper(
network: Network, request: pytest.FixtureRequest network: Network, request: pytest.FixtureRequest, pytestconfig: pytest.Config
) -> types.TestContainerDocker: ) -> types.TestContainerDocker:
""" """
Package-scoped fixture for Zookeeper TestContainer. Package-scoped fixture for Zookeeper TestContainer.
""" """
dev = request.config.getoption("--dev")
if dev:
cached_zookeeper = pytestconfig.cache.get("zookeeper", None)
if cached_zookeeper:
return types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
cached_zookeeper["host_config"]["scheme"],
cached_zookeeper["host_config"]["address"],
cached_zookeeper["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
cached_zookeeper["container_config"]["scheme"],
cached_zookeeper["container_config"]["address"],
cached_zookeeper["container_config"]["port"],
),
)
version = request.config.getoption("--zookeeper-version") version = request.config.getoption("--zookeeper-version")
container = DockerContainer(image=f"bitnami/zookeeper:{version}") container = DockerContainer(image=f"bitnami/zookeeper:{version}")
@ -21,12 +41,14 @@ def zookeeper(
container.start() container.start()
def stop(): def stop():
if dev:
return
container.stop(delete_volume=True) container.stop(delete_volume=True)
request.addfinalizer(stop) request.addfinalizer(stop)
return types.TestContainerDocker( cached_zookeeper = types.TestContainerDocker(
container=container,
host_config=types.TestContainerUrlConfig( host_config=types.TestContainerUrlConfig(
"tcp", "tcp",
container.get_container_host_ip(), container.get_container_host_ip(),
@ -38,3 +60,8 @@ def zookeeper(
2181, 2181,
), ),
) )
if dev:
pytestconfig.cache.set("zookeeper", dataclasses.asdict(cached_zookeeper))
return cached_zookeeper

View File

@ -27,6 +27,10 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options] [tool.pytest.ini_options]
python_files = "src/**/**.py" python_files = "src/**/**.py"
log_cli = true
log_format = "%(asctime)s [%(levelname)s] (%(filename)s:%(lineno)s) %(message)s"
log_date_format = "%Y-%m-%d %H:%M:%S"
addopts = "-ra"
[tool.pylint.main] [tool.pylint.main]
ignore = [".venv"] ignore = [".venv"]

View File

@ -3,9 +3,12 @@ from http import HTTPStatus
import requests import requests
from fixtures import types from fixtures import types
from fixtures.logger import setup_logger
logger = setup_logger(__name__)
def test_register(signoz: types.SigNoz) -> None: def test_register(signoz: types.SigNoz, get_jwt_token) -> None:
response = requests.get(signoz.self.host_config.get("/api/v1/version"), timeout=2) response = requests.get(signoz.self.host_config.get("/api/v1/version"), timeout=2)
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
@ -16,8 +19,8 @@ def test_register(signoz: types.SigNoz) -> None:
json={ json={
"name": "admin", "name": "admin",
"orgId": "", "orgId": "",
"orgName": "", "orgName": "integration.test",
"email": "admin@admin.com", "email": "admin@integration.test",
"password": "password", "password": "password",
}, },
timeout=2, timeout=2,
@ -28,3 +31,175 @@ def test_register(signoz: types.SigNoz) -> None:
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
assert response.json()["setupCompleted"] is True assert response.json()["setupCompleted"] is True
admin_token = get_jwt_token("admin@integration.test", "password")
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
user_response = response.json()
found_user = next(
(user for user in user_response if user["email"] == "admin@integration.test"),
None,
)
assert found_user is not None
assert found_user["role"] == "ADMIN"
response = requests.get(
signoz.self.host_config.get(f"/api/v1/rbac/role/{found_user["id"]}"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
assert response.json()["group_name"] == "ADMIN"
def test_invite_and_register(signoz: types.SigNoz, get_jwt_token) -> None:
# Generate an invite token for the editor user
response = requests.post(
signoz.self.host_config.get("/api/v1/invite"),
json={"email": "editor@integration.test", "role": "EDITOR"},
timeout=2,
headers={
"Authorization": f"Bearer {get_jwt_token("admin@integration.test", "password")}" # pylint: disable=line-too-long
},
)
assert response.status_code == HTTPStatus.OK
invite_response = response.json()
assert "email" in invite_response
assert "inviteToken" in invite_response
assert invite_response["email"] == "editor@integration.test"
# Register the editor user using the invite token
response = requests.post(
signoz.self.host_config.get("/api/v1/register"),
json={
"email": "editor@integration.test",
"password": "password",
"name": "editor",
"token": f"{invite_response["inviteToken"]}",
},
timeout=2,
)
assert response.status_code == HTTPStatus.OK
# Verify that the invite token has been deleted
response = requests.get(
signoz.self.host_config.get(
f"/api/v1/invite/{invite_response["inviteToken"]}"
), # pylint: disable=line-too-long
timeout=2,
)
assert response.status_code in (HTTPStatus.NOT_FOUND, HTTPStatus.BAD_REQUEST)
# Verify that an admin endpoint cannot be called by the editor user
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={
"Authorization": f"Bearer {get_jwt_token("editor@integration.test", "password")}" # pylint: disable=line-too-long
},
)
assert response.status_code == HTTPStatus.FORBIDDEN
# Verify that the editor has been created
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={
"Authorization": f"Bearer {get_jwt_token("admin@integration.test", "password")}" # pylint: disable=line-too-long
},
)
assert response.status_code == HTTPStatus.OK
user_response = response.json()
found_user = next(
(user for user in user_response if user["email"] == "editor@integration.test"),
None,
)
assert found_user is not None
assert found_user["role"] == "EDITOR"
assert found_user["name"] == "editor"
assert found_user["email"] == "editor@integration.test"
def test_revoke_invite_and_register(signoz: types.SigNoz, get_jwt_token) -> None:
admin_token = get_jwt_token("admin@integration.test", "password")
# Generate an invite token for the viewer user
response = requests.post(
signoz.self.host_config.get("/api/v1/invite"),
json={"email": "viewer@integration.test", "role": "VIEWER"},
timeout=2,
headers={
"Authorization": f"Bearer {admin_token}" # pylint: disable=line-too-long
},
)
assert response.status_code == HTTPStatus.OK
invite_response = response.json()
assert "email" in invite_response
assert "inviteToken" in invite_response
response = requests.delete(
signoz.self.host_config.get(f"/api/v1/invite/{invite_response['email']}"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
# Try registering the viewer user with the invite token
response = requests.post(
signoz.self.host_config.get("/api/v1/register"),
json={
"email": "viewer@integration.test",
"password": "password",
"name": "viewer",
"token": f"{invite_response["inviteToken"]}",
},
timeout=2,
)
assert response.status_code in (HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND)
def test_self_access(signoz: types.SigNoz, get_jwt_token) -> None:
admin_token = get_jwt_token("admin@integration.test", "password")
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
user_response = response.json()
found_user = next(
(user for user in user_response if user["email"] == "editor@integration.test"),
None,
)
response = requests.get(
signoz.self.host_config.get(f"/api/v1/rbac/role/{found_user['id']}"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
assert response.json()["group_name"] == "EDITOR"

View File

@ -14,7 +14,7 @@ from fixtures.types import SigNoz
def test_apply_license(signoz: SigNoz, make_http_mocks, get_jwt_token) -> None: def test_apply_license(signoz: SigNoz, make_http_mocks, get_jwt_token) -> None:
make_http_mocks( make_http_mocks(
signoz.zeus.container, signoz.zeus,
[ [
Mapping( Mapping(
request=MappingRequest( request=MappingRequest(
@ -51,7 +51,7 @@ def test_apply_license(signoz: SigNoz, make_http_mocks, get_jwt_token) -> None:
], ],
) )
access_token = get_jwt_token("admin@admin.com", "password") access_token = get_jwt_token("admin@integration.test", "password")
response = requests.post( response = requests.post(
url=signoz.self.host_config.get("/api/v3/licenses"), url=signoz.self.host_config.get("/api/v3/licenses"),

View File

@ -0,0 +1,54 @@
from fixtures import types
import requests
from http import HTTPStatus
def test_api_key(signoz: types.SigNoz, get_jwt_token) -> None:
admin_token = get_jwt_token("admin@integration.test", "password")
response = requests.post(
signoz.self.host_config.get("/api/v1/pats"),
headers={"Authorization": f"Bearer {admin_token}"},
json={
"name": "admin",
"role": "ADMIN",
"expiresInDays": 1,
},
)
assert response.status_code == HTTPStatus.OK
pat_response = response.json()
assert "data" in pat_response
assert "token" in pat_response["data"]
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={"SIGNOZ-API-KEY": f"{pat_response["data"]["token"]}"},
)
assert response.status_code == HTTPStatus.OK
user_response = response.json()
found_user = next(
(user for user in user_response if user["email"] == "admin@integration.test"),
None,
)
response = requests.get(
signoz.self.host_config.get("/api/v1/pats"),
headers={"SIGNOZ-API-KEY": f"{pat_response["data"]["token"]}"},
)
assert response.status_code == HTTPStatus.OK
assert "data" in response.json()
found_pat = next(
(pat for pat in response.json()["data"] if pat["userId"] == found_user["id"]),
None,
)
assert found_pat is not None
assert found_pat["userId"] == found_user["id"]
assert found_pat["name"] == "admin"
assert found_pat["role"] == "ADMIN"