feat(auth): drop group table (#7672)

### Summary

drop group table
This commit is contained in:
Vibhu Pandey 2025-04-26 15:50:02 +05:30 committed by GitHub
parent b60588a749
commit 9e449e2858
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
84 changed files with 1762 additions and 1227 deletions

View File

@ -0,0 +1,27 @@
services:
postgres:
image: postgres:15
container_name: postgres
environment:
POSTGRES_DB: signoz
POSTGRES_USER: postgres
POSTGRES_PASSWORD: password
healthcheck:
test:
[
"CMD",
"pg_isready",
"-d",
"signoz",
"-U",
"postgres"
]
interval: 30s
timeout: 30s
retries: 3
restart: on-failure
ports:
- "127.0.0.1:5432:5432/tcp"
volumes:
- ${PWD}/fs/tmp/var/lib/postgresql/data/:/var/lib/postgresql/data/

View File

@ -44,10 +44,8 @@ jobs:
- name: run
run: |
cd tests/integration && \
poetry run pytest -ra \
poetry run pytest \
--basetemp=./tmp/ \
-vv \
--capture=no \
src/${{matrix.src}} \
--sqlstore-provider ${{matrix.sqlstore-provider}} \
--postgres-version ${{matrix.postgres-version}} \

View File

@ -56,6 +56,11 @@ devenv-clickhouse: ## Run clickhouse in devenv
@cd .devenv/docker/clickhouse; \
docker compose -f compose.yaml up -d
.PHONY: devenv-postgres
devenv-postgres: ## Run postgres in devenv
@cd .devenv/docker/postgres; \
docker compose -f compose.yaml up -d
##############################################################
# go commands
##############################################################

View File

@ -61,11 +61,17 @@ func (p *Pat) Wrap(next http.Handler) http.Handler {
return
}
role, err := authtypes.NewRole(user.Role)
if err != nil {
next.ServeHTTP(w, r)
return
}
jwt := authtypes.Claims{
UserID: user.ID,
GroupID: user.GroupID,
Email: user.Email,
OrgID: user.OrgID,
UserID: user.ID,
Role: role,
Email: user.Email,
OrgID: user.OrgID,
}
ctx = authtypes.NewContextWithClaims(ctx, jwt)

View File

@ -12,6 +12,7 @@ import (
"github.com/SigNoz/signoz/ee/query-service/usage"
"github.com/SigNoz/signoz/pkg/alertmanager"
"github.com/SigNoz/signoz/pkg/apis/fields"
"github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/modules/preference"
preferencecore "github.com/SigNoz/signoz/pkg/modules/preference/core"
@ -126,7 +127,7 @@ func (ah *APIHandler) CheckFeature(f string) bool {
}
// RegisterRoutes registers routes for this handler on the given router
func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *baseapp.AuthMiddleware) {
func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *middleware.AuthZ) {
// note: add ee override methods first
// routes available only in ee version
@ -199,7 +200,7 @@ func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *baseapp.AuthMiddlew
}
func (ah *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *baseapp.AuthMiddleware) {
func (ah *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *middleware.AuthZ) {
ah.APIHandler.RegisterCloudIntegrationsRoutes(router, am)

View File

@ -12,11 +12,11 @@ import (
"github.com/SigNoz/signoz/ee/query-service/constants"
eeTypes "github.com/SigNoz/signoz/ee/types"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/query-service/auth"
baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
basemodel "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/google/uuid"
"github.com/gorilla/mux"
"go.uber.org/zap"
@ -30,6 +30,12 @@ type CloudIntegrationConnectionParamsResponse struct {
}
func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseWriter, r *http.Request) {
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
cloudProvider := mux.Vars(r)["cloudProvider"]
if cloudProvider != "aws" {
RespondError(w, basemodel.BadRequest(fmt.Errorf(
@ -38,15 +44,7 @@ func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseW
return
}
currentUser, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, basemodel.UnauthorizedError(fmt.Errorf(
"couldn't deduce current user: %w", err,
)), nil)
return
}
apiKey, apiErr := ah.getOrCreateCloudIntegrationPAT(r.Context(), currentUser.OrgID, cloudProvider)
apiKey, apiErr := ah.getOrCreateCloudIntegrationPAT(r.Context(), claims.OrgID, cloudProvider)
if apiErr != nil {
RespondError(w, basemodel.WrapApiError(
apiErr, "couldn't provision PAT for cloud integration:",
@ -137,7 +135,7 @@ func (ah *APIHandler) getOrCreateCloudIntegrationPAT(ctx context.Context, orgId
newPAT := eeTypes.NewGettablePAT(
integrationPATName,
baseconstants.ViewerGroup,
authtypes.RoleViewer.String(),
integrationUser.ID,
0,
)
@ -181,11 +179,7 @@ func (ah *APIHandler) getOrCreateCloudIntegrationUser(
OrgID: orgId,
}
viewerGroup, apiErr := dao.DB().GetGroupByName(ctx, baseconstants.ViewerGroup)
if apiErr != nil {
return nil, basemodel.WrapApiError(apiErr, "couldn't get viewer group for creating integration user")
}
newUser.GroupID = viewerGroup.ID
newUser.Role = authtypes.RoleViewer.String()
passwordHash, err := auth.PasswordHash(uuid.NewString())
if err != nil {

View File

@ -7,7 +7,6 @@ import (
"github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/query-service/app/dashboards"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/gorilla/mux"
)
@ -36,18 +35,19 @@ func (ah *APIHandler) lockUnlockDashboard(w http.ResponseWriter, r *http.Request
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
return
}
dashboard, err := dashboards.GetDashboard(r.Context(), claims.OrgID, uuid)
if err != nil {
render.Error(w, errors.Wrapf(err, errors.TypeInternal, errors.CodeInternal, "failed to get dashboard"))
return
}
if !auth.IsAdminV2(claims) && (dashboard.CreatedBy != claims.Email) {
if err := claims.IsAdmin(); err != nil && (dashboard.CreatedBy != claims.Email) {
render.Error(w, errors.Newf(errors.TypeForbidden, errors.CodeForbidden, "You are not authorized to lock/unlock this dashboard"))
return
}

View File

@ -1,7 +1,6 @@
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -13,35 +12,31 @@ import (
"github.com/SigNoz/signoz/pkg/errors"
errorsV2 "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/query-service/auth"
baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants"
basemodel "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/valuer"
"github.com/gorilla/mux"
"go.uber.org/zap"
)
func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
req := model.CreatePATRequestBody{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
RespondError(w, model.BadRequest(err), nil)
return
}
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
pat := eeTypes.NewGettablePAT(
req.Name,
req.Role,
user.ID,
claims.UserID,
req.ExpiresInDays,
)
err = validatePATRequest(pat)
@ -52,7 +47,7 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
zap.L().Info("Got Create PAT request", zap.Any("pat", pat))
var apierr basemodel.BaseApiError
if pat, apierr = ah.AppDao().CreatePAT(ctx, user.OrgID, pat); apierr != nil {
if pat, apierr = ah.AppDao().CreatePAT(r.Context(), claims.OrgID, pat); apierr != nil {
RespondError(w, apierr, nil)
return
}
@ -61,20 +56,28 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
}
func validatePATRequest(req eeTypes.GettablePAT) error {
if req.Role == "" || (req.Role != baseconstants.ViewerGroup && req.Role != baseconstants.EditorGroup && req.Role != baseconstants.AdminGroup) {
return fmt.Errorf("valid role is required")
_, err := authtypes.NewRole(req.Role)
if err != nil {
return err
}
if req.ExpiresAt < 0 {
return fmt.Errorf("valid expiresAt is required")
}
if req.Name == "" {
return fmt.Errorf("valid name is required")
}
return nil
}
func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
req := eeTypes.GettablePAT{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -89,24 +92,15 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
return
}
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
//get the pat
existingPAT, paterr := ah.AppDao().GetPATByID(ctx, user.OrgID, id)
existingPAT, paterr := ah.AppDao().GetPATByID(r.Context(), claims.OrgID, id)
if paterr != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, paterr.Error()))
return
}
// get the user
createdByUser, usererr := ah.AppDao().GetUser(ctx, existingPAT.UserID)
createdByUser, usererr := ah.AppDao().GetUser(r.Context(), existingPAT.UserID)
if usererr != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, usererr.Error()))
return
@ -123,11 +117,11 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
return
}
req.UpdatedByUserID = user.ID
req.UpdatedByUserID = claims.UserID
req.UpdatedAt = time.Now()
zap.L().Info("Got Update PAT request", zap.Any("pat", req))
var apierr basemodel.BaseApiError
if apierr = ah.AppDao().UpdatePAT(ctx, user.OrgID, req, id); apierr != nil {
if apierr = ah.AppDao().UpdatePAT(r.Context(), claims.OrgID, req, id); apierr != nil {
RespondError(w, apierr, nil)
return
}
@ -136,50 +130,44 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
}
func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
user, err := auth.GetUserFromReqContext(r.Context())
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
render.Error(w, err)
return
}
zap.L().Info("Get PATs for user", zap.String("user_id", user.ID))
pats, apierr := ah.AppDao().ListPATs(ctx, user.OrgID)
pats, apierr := ah.AppDao().ListPATs(r.Context(), claims.OrgID)
if apierr != nil {
RespondError(w, apierr, nil)
return
}
ah.Respond(w, pats)
}
func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
idStr := mux.Vars(r)["id"]
id, err := valuer.NewUUID(idStr)
if err != nil {
render.Error(w, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "id is not a valid uuid-v7"))
return
}
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
//get the pat
existingPAT, paterr := ah.AppDao().GetPATByID(ctx, user.OrgID, id)
existingPAT, paterr := ah.AppDao().GetPATByID(r.Context(), claims.OrgID, id)
if paterr != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, paterr.Error()))
return
}
// get the user
createdByUser, usererr := ah.AppDao().GetUser(ctx, existingPAT.UserID)
createdByUser, usererr := ah.AppDao().GetUser(r.Context(), existingPAT.UserID)
if usererr != nil {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, usererr.Error()))
return
@ -191,7 +179,7 @@ func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) {
}
zap.L().Info("Revoke PAT with id", zap.String("id", id.StringValue()))
if apierr := ah.AppDao().RevokePAT(ctx, user.OrgID, id, user.ID); apierr != nil {
if apierr := ah.AppDao().RevokePAT(r.Context(), claims.OrgID, id, claims.UserID); apierr != nil {
RespondError(w, apierr, nil)
return
}

View File

@ -2,7 +2,6 @@ package app
import (
"context"
"errors"
"fmt"
"net"
"net/http"
@ -22,11 +21,9 @@ import (
"github.com/SigNoz/signoz/pkg/alertmanager"
"github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/prometheus"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/signoz"
"github.com/SigNoz/signoz/pkg/sqlstore"
"github.com/SigNoz/signoz/pkg/telemetrystore"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/web"
"github.com/rs/cors"
@ -334,24 +331,8 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server,
}
func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*http.Server, error) {
r := baseapp.NewRouter()
// add auth middleware
getUserFromRequest := func(ctx context.Context) (*types.GettableUser, error) {
user, err := auth.GetUserFromReqContext(ctx)
if err != nil {
return nil, err
}
if user.User.OrgID == "" {
return nil, basemodel.UnauthorizedError(errors.New("orgId is missing in the claims"))
}
return user, nil
}
am := baseapp.NewAuthMiddleware(getUserFromRequest)
am := middleware.NewAuthZ(s.serverOptions.SigNoz.Instrumentation.Logger())
r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
r.Use(eemiddleware.NewPat(s.serverOptions.SigNoz.SQLStore, []string{"SIGNOZ-API-KEY"}).Wrap)

View File

@ -9,7 +9,6 @@ import (
"github.com/SigNoz/signoz/ee/query-service/constants"
"github.com/SigNoz/signoz/ee/query-service/model"
baseauth "github.com/SigNoz/signoz/pkg/query-service/auth"
baseconst "github.com/SigNoz/signoz/pkg/query-service/constants"
basemodel "github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/query-service/utils"
"github.com/SigNoz/signoz/pkg/types"
@ -36,12 +35,6 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) (
return nil, model.InternalErrorStr("failed to generate password hash")
}
group, apiErr := m.GetGroupByName(ctx, baseconst.ViewerGroup)
if apiErr != nil {
zap.L().Error("GetGroupByName failed", zap.Error(apiErr))
return nil, apiErr
}
user := &types.User{
ID: uuid.New().String(),
Name: "",
@ -51,11 +44,11 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) (
CreatedAt: time.Now(),
},
ProfilePictureURL: "", // Currently unused
GroupID: group.ID,
Role: authtypes.RoleViewer.String(),
OrgID: domain.OrgID,
}
user, apiErr = m.CreateUser(ctx, user, false)
user, apiErr := m.CreateUser(ctx, user, false)
if apiErr != nil {
zap.L().Error("CreateUser failed", zap.Error(apiErr))
return nil, apiErr
@ -115,7 +108,7 @@ func (m *modelDao) CanUsePassword(ctx context.Context, email string) (bool, base
return false, baseapierr
}
if userPayload.Role != baseconst.AdminGroup {
if userPayload.Role != authtypes.RoleAdmin.String() {
return false, model.BadRequest(fmt.Errorf("auth method not supported"))
}

View File

@ -239,8 +239,8 @@ func (lm *Manager) ValidateV3(ctx context.Context) (reterr error) {
func (lm *Manager) ActivateV3(ctx context.Context, licenseKey string) (licenseResponse *model.LicenseV3, errResponse *model.ApiError) {
defer func() {
if errResponse != nil {
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_LICENSE_ACT_FAILED,
map[string]interface{}{"err": errResponse.Err.Error()}, claims.Email, true, false)
}

View File

@ -11,7 +11,6 @@ import (
"github.com/SigNoz/signoz/pkg/config"
"github.com/SigNoz/signoz/pkg/config/envprovider"
"github.com/SigNoz/signoz/pkg/config/fileprovider"
"github.com/SigNoz/signoz/pkg/query-service/auth"
baseconst "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/signoz"
"github.com/SigNoz/signoz/pkg/sqlstore/sqlstorehook"
@ -147,10 +146,6 @@ func main() {
zap.L().Fatal("Could not start server", zap.Error(err))
}
if err := auth.InitAuthCache(context.Background()); err != nil {
zap.L().Fatal("Failed to initialize auth cache", zap.Error(err))
}
signoz.Start(context.Background())
if err := signoz.Wait(context.Background()); err != nil {

View File

@ -28,10 +28,9 @@ var (
CloudIntegrationReference = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE`
)
type dialect struct {
}
type dialect struct{}
func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
columnType, err := dialect.GetColumnType(ctx, bun, table, column)
if err != nil {
return err
@ -78,7 +77,7 @@ func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB,
return nil
}
func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
columnExists, err := dialect.ColumnExists(ctx, bun, table, column)
if err != nil {
return err
@ -420,3 +419,26 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel
return nil
}
func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bunIDB bun.IDB, model interface{}, column string) error {
existingTable := bunIDB.Dialect().Tables().Get(reflect.TypeOf(model))
columnExists, err := dialect.ColumnExists(ctx, bunIDB, existingTable.Name, column)
if err != nil {
return err
}
if !columnExists {
return nil
}
_, err = bunIDB.
NewDropColumn().
Model(model).
Column(column).
Exec(ctx)
if err != nil {
return err
}
return nil
}

1
frontend/.gitignore vendored
View File

@ -1,3 +1,4 @@
# Sentry Config File
.env.sentry-build-plugin
.qodo

View File

@ -23,7 +23,6 @@ function getUserDefaults(): IUser {
organization: '',
orgId: '',
role: 'VIEWER',
groupId: '',
};
}

View File

@ -151,7 +151,6 @@ export function getAppContextMock(
organization: 'Nightswatch',
orgId: 'does-not-matter-id',
role: role as ROLES,
groupId: 'does-not-matter-groupId',
},
org: [
{

View File

@ -15,5 +15,4 @@ export interface PayloadProps {
profilePictureURL: string;
organization: string;
role: ROLES;
groupId: string;
}

View File

@ -28,9 +28,9 @@ func (api *API) GetAlerts(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
render.Error(rw, err)
return
}
@ -53,9 +53,9 @@ func (api *API) TestReceiver(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
render.Error(rw, err)
return
}
@ -85,9 +85,9 @@ func (api *API) ListChannels(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
render.Error(rw, err)
return
}
@ -122,9 +122,9 @@ func (api *API) GetChannelByID(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
render.Error(rw, err)
return
}
@ -159,9 +159,9 @@ func (api *API) UpdateChannelByID(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
render.Error(rw, err)
return
}
@ -209,9 +209,9 @@ func (api *API) DeleteChannelByID(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
render.Error(rw, err)
return
}
@ -246,9 +246,9 @@ func (api *API) CreateChannel(rw http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
defer cancel()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
render.Error(rw, err)
return
}

View File

@ -46,8 +46,8 @@ func (a *Analytics) Wrap(next http.Handler) http.Handler {
}
if _, ok := telemetry.EnabledPaths()[path]; ok {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if ok {
claims, err := authtypes.ClaimsFromContext(r.Context())
if err == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, claims.Email, true, false)
}
}
@ -134,8 +134,8 @@ func (a *Analytics) extractQueryRangeData(path string, r *http.Request) (map[str
data["queryType"] = queryInfoResult.QueryType
data["panelType"] = queryInfoResult.PanelType
claims, ok := authtypes.ClaimsFromContext(r.Context())
if ok {
claims, err := authtypes.ClaimsFromContext(r.Context())
if err == nil {
// switch case to set data["screen"] based on the referrer
switch {
case dashboardMatched:

View File

@ -28,9 +28,7 @@ func (a *Auth) Wrap(next http.Handler) http.Handler {
values = append(values, r.Header.Get(header))
}
ctx, err := a.jwt.ContextFromRequest(
r.Context(),
values...)
ctx, err := a.jwt.ContextFromRequest(r.Context(), values...)
if err != nil {
next.ServeHTTP(w, r)
return

View File

@ -0,0 +1,105 @@
package middleware
import (
"log/slog"
"net/http"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/gorilla/mux"
)
const (
authzDeniedMessage string = "::AUTHZ-DENIED::"
)
type AuthZ struct {
logger *slog.Logger
}
func NewAuthZ(logger *slog.Logger) *AuthZ {
if logger == nil {
panic("cannot build authz middleware, logger is empty")
}
return &AuthZ{logger: logger}
}
func (middleware *AuthZ) ViewAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
claims, err := authtypes.ClaimsFromContext(req.Context())
if err != nil {
render.Error(rw, err)
return
}
if err := claims.IsViewer(); err != nil {
middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims)
render.Error(rw, err)
return
}
next(rw, req)
})
}
func (middleware *AuthZ) EditAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
claims, err := authtypes.ClaimsFromContext(req.Context())
if err != nil {
render.Error(rw, err)
return
}
if err := claims.IsEditor(); err != nil {
middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims)
render.Error(rw, err)
return
}
next(rw, req)
})
}
func (middleware *AuthZ) AdminAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
claims, err := authtypes.ClaimsFromContext(req.Context())
if err != nil {
render.Error(rw, err)
return
}
if err := claims.IsAdmin(); err != nil {
middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims)
render.Error(rw, err)
return
}
next(rw, req)
})
}
func (middleware *AuthZ) SelfAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
claims, err := authtypes.ClaimsFromContext(req.Context())
if err != nil {
render.Error(rw, err)
return
}
id := mux.Vars(req)["id"]
if err := claims.IsSelfAccess(id); err != nil {
middleware.logger.WarnContext(req.Context(), authzDeniedMessage, "claims", claims)
render.Error(rw, err)
return
}
next(rw, req)
})
}
func (middleware *AuthZ) OpenAccess(next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next(rw, req)
})
}

View File

@ -136,8 +136,8 @@ func (middleware *Logging) getLogCommentKVs(r *http.Request) map[string]string {
}
var email string
claims, ok := authtypes.ClaimsFromContext(r.Context())
if ok {
claims, err := authtypes.ClaimsFromContext(r.Context())
if err == nil {
email = claims.Email
}

View File

@ -58,6 +58,8 @@ func Error(rw http.ResponseWriter, cause error) {
httpCode = http.StatusUnauthorized
case errors.TypeUnsupported:
httpCode = http.StatusNotImplemented
case errors.TypeForbidden:
httpCode = http.StatusForbidden
}
rea := make([]responseerroradditional, len(a))

View File

@ -21,11 +21,12 @@ func NewAPI(module organization.Module) organization.API {
}
func (api *organizationAPI) Get(rw http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(rw, err)
return
}
orgID, err := valuer.NewUUID(claims.OrgID)
if err != nil {
render.Error(rw, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid org id"))
@ -52,11 +53,12 @@ func (api *organizationAPI) GetAll(rw http.ResponseWriter, r *http.Request) {
}
func (api *organizationAPI) Update(rw http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(rw, err)
return
}
orgID, err := valuer.NewUUID(claims.OrgID)
if err != nil {
render.Error(rw, errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid org id"))

View File

@ -4,7 +4,6 @@ import (
"encoding/json"
"net/http"
errorsV2 "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/types/preferencetypes"
@ -30,9 +29,9 @@ func NewAPI(usecase Usecase) API {
func (p *preferenceAPI) GetOrgPreference(rw http.ResponseWriter, r *http.Request) {
preferenceId := mux.Vars(r)["preferenceId"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(rw, err)
return
}
preference, err := p.usecase.GetOrgPreference(
@ -49,18 +48,18 @@ func (p *preferenceAPI) GetOrgPreference(rw http.ResponseWriter, r *http.Request
func (p *preferenceAPI) UpdateOrgPreference(rw http.ResponseWriter, r *http.Request) {
preferenceId := mux.Vars(r)["preferenceId"]
req := preferencetypes.UpdatablePreference{}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
return
}
err := json.NewDecoder(r.Body).Decode(&req)
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(rw, err)
return
}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
render.Error(rw, err)
return
}
err = p.usecase.UpdateOrgPreference(r.Context(), preferenceId, req.PreferenceValue, claims.OrgID)
if err != nil {
render.Error(rw, err)
@ -71,9 +70,9 @@ func (p *preferenceAPI) UpdateOrgPreference(rw http.ResponseWriter, r *http.Requ
}
func (p *preferenceAPI) GetAllOrgPreferences(rw http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(rw, err)
return
}
preferences, err := p.usecase.GetAllOrgPreferences(
@ -89,9 +88,9 @@ func (p *preferenceAPI) GetAllOrgPreferences(rw http.ResponseWriter, r *http.Req
func (p *preferenceAPI) GetUserPreference(rw http.ResponseWriter, r *http.Request) {
preferenceId := mux.Vars(r)["preferenceId"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(rw, err)
return
}
@ -108,14 +107,14 @@ func (p *preferenceAPI) GetUserPreference(rw http.ResponseWriter, r *http.Reques
func (p *preferenceAPI) UpdateUserPreference(rw http.ResponseWriter, r *http.Request) {
preferenceId := mux.Vars(r)["preferenceId"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(rw, err)
return
}
req := preferencetypes.UpdatablePreference{}
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
render.Error(rw, err)
@ -131,9 +130,9 @@ func (p *preferenceAPI) UpdateUserPreference(rw http.ResponseWriter, r *http.Req
}
func (p *preferenceAPI) GetAllUserPreferences(rw http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(rw, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, err := authtypes.ClaimsFromContext(r.Context())
if err != nil {
render.Error(rw, err)
return
}
preferences, err := p.usecase.GetAllUserPreferences(

View File

@ -1,19 +1,19 @@
package app
import (
"errors"
"net/http"
"strings"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types/authtypes"
)
func (aH *APIHandler) setApdexSettings(w http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
RespondError(w, &model.ApiError{Err: errors.New("unauthorized"), Typ: model.ErrorUnauthorized}, nil)
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
req, err := parseSetApdexScoreRequest(r)
@ -31,9 +31,9 @@ func (aH *APIHandler) setApdexSettings(w http.ResponseWriter, r *http.Request) {
func (aH *APIHandler) getApdexSettings(w http.ResponseWriter, r *http.Request) {
services := r.URL.Query().Get("services")
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
RespondError(w, &model.ApiError{Err: errors.New("unauthorized"), Typ: model.ErrorUnauthorized}, nil)
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
apdexSet, err := dao.DB().GetApdexSettings(r.Context(), claims.OrgID, strings.Split(strings.TrimSpace(services), ","))

View File

@ -1,123 +0,0 @@
package app
import (
"context"
"errors"
"net/http"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types"
"github.com/gorilla/mux"
)
type AuthMiddleware struct {
GetUserFromRequest func(r context.Context) (*types.GettableUser, error)
}
func NewAuthMiddleware(f func(ctx context.Context) (*types.GettableUser, error)) *AuthMiddleware {
return &AuthMiddleware{
GetUserFromRequest: f,
}
}
func (am *AuthMiddleware) OpenAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
f(w, r)
}
}
func (am *AuthMiddleware) ViewAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
if !(auth.IsViewer(user) || auth.IsEditor(user) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to viewers/editors/admins"),
}, nil)
return
}
ctx := context.WithValue(r.Context(), constants.ContextUserKey, user)
r = r.WithContext(ctx)
f(w, r)
}
}
func (am *AuthMiddleware) EditAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
if !(auth.IsEditor(user) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to editors/admins"),
}, nil)
return
}
ctx := context.WithValue(r.Context(), constants.ContextUserKey, user)
r = r.WithContext(ctx)
f(w, r)
}
}
func (am *AuthMiddleware) SelfAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
id := mux.Vars(r)["id"]
if !(auth.IsSelfAccessRequest(user, id) || auth.IsAdmin(user)) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible for self access or to the admins"),
}, nil)
return
}
ctx := context.WithValue(r.Context(), constants.ContextUserKey, user)
r = r.WithContext(ctx)
f(w, r)
}
}
func (am *AuthMiddleware) AdminAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := am.GetUserFromRequest(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Err: err,
}, nil)
return
}
if !auth.IsAdmin(user) {
RespondError(w, &model.ApiError{
Typ: model.ErrorForbidden,
Err: errors.New("API is accessible to admins only"),
}, nil)
return
}
ctx := context.WithValue(r.Context(), constants.ContextUserKey, user)
r = r.WithContext(ctx)
f(w, r)
}
}

View File

@ -1037,7 +1037,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con
var serviceNameIntervalMap = map[string][]tracedetail.Interval{}
var hasMissingSpans bool
claims, claimsPresent := authtypes.ClaimsFromContext(ctx)
claims, errv2 := authtypes.ClaimsFromContext(ctx)
cachedTraceData, err := r.GetWaterfallSpansForTraceWithMetadataCache(ctx, traceID)
if err == nil {
startTime = cachedTraceData.StartTime
@ -1050,7 +1050,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con
totalErrorSpans = cachedTraceData.TotalErrorSpans
hasMissingSpans = cachedTraceData.HasMissingSpans
if claimsPresent {
if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false)
}
}
@ -1067,7 +1067,7 @@ func (r *ClickHouseReader) GetWaterfallSpansForTraceWithMetadata(ctx context.Con
}
totalSpans = uint64(len(searchScanResponses))
if claimsPresent {
if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_TRACE_DETAIL_API, map[string]interface{}{"traceSize": totalSpans}, claims.Email, true, false)
}
@ -3280,8 +3280,8 @@ func (r *ClickHouseReader) GetLogs(ctx context.Context, params *model.LogsFilter
"lenFilters": lenFilters,
}
if lenFilters != 0 {
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false)
}
}
@ -3322,8 +3322,8 @@ func (r *ClickHouseReader) TailLogs(ctx context.Context, client *model.LogsTailC
"lenFilters": lenFilters,
}
if lenFilters != 0 {
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false)
}
}
@ -3414,8 +3414,8 @@ func (r *ClickHouseReader) AggregateLogs(ctx context.Context, params *model.Logs
"lenFilters": lenFilters,
}
if lenFilters != 0 {
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_LOGS_FILTERS, data, claims.Email, true, false)
}
}
@ -6835,8 +6835,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea
if traceSummary.NumSpans > uint64(params.MaxSpansInTrace) {
zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace),
zap.Uint64("Count", traceSummary.NumSpans))
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
data := map[string]interface{}{
"traceSize": traceSummary.NumSpans,
"maxSpansInTraceLimit": params.MaxSpansInTrace,
@ -6847,8 +6847,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea
return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details")
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
data := map[string]interface{}{
"traceSize": traceSummary.NumSpans,
"algo": "smart",
@ -6937,8 +6937,8 @@ func (r *ClickHouseReader) SearchTracesV2(ctx context.Context, params *model.Sea
}
end = time.Now()
zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start)))
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
data := map[string]interface{}{
"traceSize": len(searchScanResponses),
"spansRenderLimit": params.SpansRenderLimit,
@ -6976,8 +6976,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc
if countSpans > uint64(params.MaxSpansInTrace) {
zap.L().Error("Max spans allowed in a trace limit reached", zap.Int("MaxSpansInTrace", params.MaxSpansInTrace),
zap.Uint64("Count", countSpans))
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
data := map[string]interface{}{
"traceSize": countSpans,
"maxSpansInTraceLimit": params.MaxSpansInTrace,
@ -6988,8 +6988,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc
return nil, fmt.Errorf("max spans allowed in trace limit reached, please contact support for more details")
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
data := map[string]interface{}{
"traceSize": countSpans,
"algo": "smart",
@ -7049,8 +7049,8 @@ func (r *ClickHouseReader) SearchTraces(ctx context.Context, params *model.Searc
}
end = time.Now()
zap.L().Debug("smartTraceAlgo took: ", zap.Duration("duration", end.Sub(start)))
claims, ok := authtypes.ClaimsFromContext(ctx)
if ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 == nil {
data := map[string]interface{}{
"traceSize": len(searchScanResponses),
"spansRenderLimit": params.SpansRenderLimit,

View File

@ -6,12 +6,11 @@ import (
"github.com/SigNoz/signoz/pkg/modules/organization"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/query-service/utils"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
@ -300,13 +299,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
return nil, model.InternalError(err)
}
group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup)
if apiErr != nil {
return nil, model.InternalError(apiErr)
}
auth.InitAuthCache(ctx)
userId := uuid.NewString()
return dao.DB().CreateUser(
ctx,
@ -316,7 +308,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
Email: userId[:8] + "test@test.com",
Password: "test",
OrgID: organization.ID.StringValue(),
GroupID: group.ID,
Role: authtypes.RoleAdmin.String(),
},
true,
)

View File

@ -108,8 +108,8 @@ func CreateView(ctx context.Context, orgID string, view v3.SavedView) (valuer.UU
createdAt := time.Now()
updatedAt := time.Now()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
return valuer.UUID{}, fmt.Errorf("error in getting email from context")
}
@ -177,8 +177,8 @@ func UpdateView(ctx context.Context, orgID string, uuid valuer.UUID, view v3.Sav
return fmt.Errorf("error in marshalling explorer query data: %s", err.Error())
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
return fmt.Errorf("error in getting email from context")
}

View File

@ -21,6 +21,7 @@ import (
"github.com/SigNoz/signoz/pkg/alertmanager"
"github.com/SigNoz/signoz/pkg/apis/fields"
errorsV2 "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/http/render"
"github.com/SigNoz/signoz/pkg/modules/organization"
"github.com/SigNoz/signoz/pkg/modules/preference"
@ -54,7 +55,6 @@ import (
tracesV4 "github.com/SigNoz/signoz/pkg/query-service/app/traces/v4"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/cache"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/contextlinks"
v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3"
"github.com/SigNoz/signoz/pkg/query-service/postprocess"
@ -405,7 +405,7 @@ func writeHttpResponse(w http.ResponseWriter, data interface{}) {
}
}
func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v3").Subrouter()
subRouter.HandleFunc("/autocomplete/aggregate_attributes", am.ViewAccess(
withCacheControl(AutoCompleteCacheControlAge, aH.autocompleteAggregateAttributes))).Methods(http.MethodGet)
@ -431,14 +431,14 @@ func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *AuthMid
subRouter.HandleFunc("/logs/livetail", am.ViewAccess(aH.liveTailLogs)).Methods(http.MethodGet)
}
func (aH *APIHandler) RegisterFieldsRoutes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterFieldsRoutes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v1").Subrouter()
subRouter.HandleFunc("/fields/keys", am.ViewAccess(aH.FieldsAPI.GetFieldsKeys)).Methods(http.MethodGet)
subRouter.HandleFunc("/fields/values", am.ViewAccess(aH.FieldsAPI.GetFieldsValues)).Methods(http.MethodGet)
}
func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *middleware.AuthZ) {
hostsSubRouter := router.PathPrefix("/api/v1/hosts").Subrouter()
hostsSubRouter.HandleFunc("/attribute_keys", am.ViewAccess(aH.getHostAttributeKeys)).Methods(http.MethodGet)
hostsSubRouter.HandleFunc("/attribute_values", am.ViewAccess(aH.getHostAttributeValues)).Methods(http.MethodGet)
@ -498,12 +498,12 @@ func (aH *APIHandler) RegisterInfraMetricsRoutes(router *mux.Router, am *AuthMid
infraOnboardingSubRouter.HandleFunc("/k8s/status", am.ViewAccess(aH.getK8sInfraOnboardingStatus)).Methods(http.MethodGet)
}
func (aH *APIHandler) RegisterWebSocketPaths(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterWebSocketPaths(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/ws").Subrouter()
subRouter.HandleFunc("/query_progress", am.ViewAccess(aH.GetQueryProgressUpdates)).Methods(http.MethodGet)
}
func (aH *APIHandler) RegisterQueryRangeV4Routes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterQueryRangeV4Routes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v4").Subrouter()
subRouter.HandleFunc("/query_range", am.ViewAccess(aH.QueryRangeV4)).Methods(http.MethodPost)
subRouter.HandleFunc("/metric/metric_metadata", am.ViewAccess(aH.getMetricMetadata)).Methods(http.MethodGet)
@ -520,7 +520,7 @@ func (aH *APIHandler) RegisterPrivateRoutes(router *mux.Router) {
}
// RegisterRoutes registers routes for this handler on the given router
func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *middleware.AuthZ) {
router.HandleFunc("/api/v1/query_range", am.ViewAccess(aH.queryRangeMetrics)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/query", am.ViewAccess(aH.queryMetrics)).Methods(http.MethodGet)
router.HandleFunc("/api/v1/channels", am.ViewAccess(aH.AlertmanagerAPI.ListChannels)).Methods(http.MethodGet)
@ -649,7 +649,7 @@ func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *AuthMiddleware) {
})).Methods(http.MethodGet)
}
func (ah *APIHandler) MetricExplorerRoutes(router *mux.Router, am *AuthMiddleware) {
func (ah *APIHandler) MetricExplorerRoutes(router *mux.Router, am *middleware.AuthZ) {
router.HandleFunc("/api/v1/metrics/filters/keys",
am.ViewAccess(ah.FilterKeysSuggestion)).
Methods(http.MethodGet)
@ -757,9 +757,9 @@ func (aH *APIHandler) PopulateTemporality(ctx context.Context, qp *v3.QueryRange
}
func (aH *APIHandler) listDowntimeSchedules(w http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -1119,10 +1119,9 @@ func (aH *APIHandler) listRules(w http.ResponseWriter, r *http.Request) {
}
func (aH *APIHandler) getDashboards(w http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
allDashboards, err := dashboards.GetDashboards(r.Context(), claims.OrgID)
@ -1189,11 +1188,10 @@ func (aH *APIHandler) getDashboards(w http.ResponseWriter, r *http.Request) {
}
func (aH *APIHandler) deleteDashboard(w http.ResponseWriter, r *http.Request) {
uuid := mux.Vars(r)["uuid"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
err := dashboards.DeleteDashboard(r.Context(), claims.OrgID, uuid)
@ -1268,7 +1266,6 @@ func (aH *APIHandler) queryDashboardVarsV2(w http.ResponseWriter, r *http.Reques
}
func (aH *APIHandler) updateDashboard(w http.ResponseWriter, r *http.Request) {
uuid := mux.Vars(r)["uuid"]
var postData map[string]interface{}
@ -1283,9 +1280,9 @@ func (aH *APIHandler) updateDashboard(w http.ResponseWriter, r *http.Request) {
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
dashboard, apiError := dashboards.UpdateDashboard(r.Context(), claims.OrgID, claims.Email, uuid, postData)
@ -1302,9 +1299,9 @@ func (aH *APIHandler) getDashboard(w http.ResponseWriter, r *http.Request) {
uuid := mux.Vars(r)["uuid"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
dashboard, apiError := dashboards.GetDashboard(r.Context(), claims.OrgID, uuid)
@ -1356,9 +1353,9 @@ func (aH *APIHandler) createDashboards(w http.ResponseWriter, r *http.Request) {
RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, "Error reading request body")
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
dash, apiErr := dashboards.CreateDashboard(r.Context(), claims.OrgID, claims.Email, postData)
@ -1612,20 +1609,19 @@ func (aH *APIHandler) submitFeedback(w http.ResponseWriter, r *http.Request) {
"email": email,
"message": message,
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if ok {
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_INPRODUCT_FEEDBACK, data, claims.Email, true, false)
}
}
func (aH *APIHandler) registerEvent(w http.ResponseWriter, r *http.Request) {
request, err := parseRegisterEventRequest(r)
if aH.HandleError(w, err, http.StatusBadRequest) {
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if ok {
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 == nil {
switch request.EventType {
case model.TrackEvent:
telemetry.GetInstance().SendEvent(request.EventName, request.Attributes, claims.Email, request.RateLimited, true)
@ -1737,8 +1733,8 @@ func (aH *APIHandler) getServices(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"number": len(*result),
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if ok {
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_NUMBER_OF_SERVICES, data, claims.Email, true, false)
}
@ -1918,8 +1914,8 @@ func (aH *APIHandler) setTTL(w http.ResponseWriter, r *http.Request) {
}
ctx := r.Context()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil)
return
}
@ -1946,8 +1942,8 @@ func (aH *APIHandler) getTTL(w http.ResponseWriter, r *http.Request) {
}
ctx := r.Context()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil)
return
}
@ -2030,9 +2026,10 @@ func (aH *APIHandler) inviteUser(w http.ResponseWriter, r *http.Request) {
resp, err := auth.Invite(r.Context(), req)
if err != nil {
RespondError(w, &model.ApiError{Err: err, Typ: model.ErrorInternal}, nil)
render.Error(w, err)
return
}
aH.WriteJSON(w, r, resp)
}
@ -2088,11 +2085,10 @@ func (aH *APIHandler) revokeInvite(w http.ResponseWriter, r *http.Request) {
// listPendingInvites is used to list the pending invites.
func (aH *APIHandler) listPendingInvites(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
RespondError(w, &model.ApiError{Err: errors.New("failed to get org id from context"), Typ: model.ErrorInternal}, nil)
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
render.Error(w, errv2)
return
}
invites, err := dao.DB().GetInvites(ctx, claims.OrgID)
@ -2309,18 +2305,13 @@ func (aH *APIHandler) deleteUser(w http.ResponseWriter, r *http.Request) {
return
}
adminGroup, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup)
if apiErr != nil {
RespondError(w, apiErr, "Failed to get admin group")
return
}
adminUsers, apiErr := dao.DB().GetUsersByGroup(ctx, adminGroup.ID)
adminUsers, apiErr := dao.DB().GetUsersByRole(ctx, authtypes.RoleAdmin)
if apiErr != nil {
RespondError(w, apiErr, "Failed to get admin group users")
return
}
if user.GroupID == adminGroup.ID && len(adminUsers) == 1 {
if user.Role == authtypes.RoleAdmin.String() && len(adminUsers) == 1 {
RespondError(w, &model.ApiError{
Typ: model.ErrorInternal,
Err: errors.New("cannot delete the last admin user")}, nil)
@ -2332,6 +2323,7 @@ func (aH *APIHandler) deleteUser(w http.ResponseWriter, r *http.Request) {
RespondError(w, err, "Failed to delete user")
return
}
aH.WriteJSON(w, r, map[string]string{"data": "user deleted successfully"})
}
@ -2350,13 +2342,8 @@ func (aH *APIHandler) getRole(w http.ResponseWriter, r *http.Request) {
}, nil)
return
}
group, err := dao.DB().GetGroup(context.Background(), user.GroupID)
if err != nil {
RespondError(w, err, "Failed to get group")
return
}
aH.WriteJSON(w, r, &model.UserRole{UserId: id, GroupName: group.Name})
aH.WriteJSON(w, r, &model.UserRole{UserId: id, GroupName: user.Role})
}
func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) {
@ -2368,14 +2355,9 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) {
}
ctx := context.Background()
newGroup, apiErr := dao.DB().GetGroupByName(ctx, req.GroupName)
if apiErr != nil {
RespondError(w, apiErr, "Failed to get user's group")
return
}
if newGroup == nil {
RespondError(w, apiErr, "Specified group is not present")
role, err := authtypes.NewRole(req.GroupName)
if err != nil {
RespondError(w, &model.ApiError{Typ: model.ErrorBadData, Err: errors.New("invalid role")}, nil)
return
}
@ -2386,8 +2368,8 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) {
}
// Make sure that the request is not demoting the last admin user.
if user.GroupID == auth.AuthCacheObj.AdminGroupId {
adminUsers, apiErr := dao.DB().GetUsersByGroup(ctx, auth.AuthCacheObj.AdminGroupId)
if user.Role == authtypes.RoleAdmin.String() {
adminUsers, apiErr := dao.DB().GetUsersByRole(ctx, authtypes.RoleAdmin)
if apiErr != nil {
RespondError(w, apiErr, "Failed to fetch adminUsers")
return
@ -2401,7 +2383,7 @@ func (aH *APIHandler) editRole(w http.ResponseWriter, r *http.Request) {
}
}
apiErr = dao.DB().UpdateUserGroup(context.Background(), user.ID, newGroup.ID)
apiErr = dao.DB().UpdateUserRole(context.Background(), user.ID, role)
if apiErr != nil {
RespondError(w, apiErr, "Failed to add user to group")
return
@ -2525,7 +2507,7 @@ func (aH *APIHandler) WriteJSON(w http.ResponseWriter, r *http.Request, response
}
// RegisterMessagingQueuesRoutes adds messaging-queues routes
func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *middleware.AuthZ) {
// Main messaging queues router
messagingQueuesRouter := router.PathPrefix("/api/v1/messaging-queues").Subrouter()
@ -2567,7 +2549,7 @@ func (aH *APIHandler) RegisterMessagingQueuesRoutes(router *mux.Router, am *Auth
}
// RegisterThirdPartyApiRoutes adds third-party-api integration routes
func (aH *APIHandler) RegisterThirdPartyApiRoutes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterThirdPartyApiRoutes(router *mux.Router, am *middleware.AuthZ) {
// Main messaging queues router
thirdPartyApiRouter := router.PathPrefix("/api/v1/third-party-apis").Subrouter()
@ -3494,7 +3476,7 @@ func (aH *APIHandler) getAllOrgPreferences(
}
// RegisterIntegrationRoutes Registers all Integrations
func (aH *APIHandler) RegisterIntegrationRoutes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterIntegrationRoutes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v1/integrations").Subrouter()
subRouter.HandleFunc(
@ -3526,9 +3508,9 @@ func (aH *APIHandler) ListIntegrations(
for k, values := range r.URL.Query() {
params[k] = values[0]
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -3546,9 +3528,9 @@ func (aH *APIHandler) GetIntegration(
w http.ResponseWriter, r *http.Request,
) {
integrationId := mux.Vars(r)["integrationId"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
integration, apiErr := aH.IntegrationsController.GetIntegration(
@ -3566,9 +3548,9 @@ func (aH *APIHandler) GetIntegrationConnectionStatus(
w http.ResponseWriter, r *http.Request,
) {
integrationId := mux.Vars(r)["integrationId"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
isInstalled, apiErr := aH.IntegrationsController.IsIntegrationInstalled(
@ -3785,9 +3767,9 @@ func (aH *APIHandler) InstallIntegration(
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -3813,9 +3795,9 @@ func (aH *APIHandler) UninstallIntegration(
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -3829,7 +3811,7 @@ func (aH *APIHandler) UninstallIntegration(
}
// cloud provider integrations
func (aH *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterCloudIntegrationsRoutes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v1/cloud-integrations").Subrouter()
subRouter.HandleFunc(
@ -3875,9 +3857,9 @@ func (aH *APIHandler) CloudIntegrationsListConnectedAccounts(
) {
cloudProvider := mux.Vars(r)["cloudProvider"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -3903,9 +3885,9 @@ func (aH *APIHandler) CloudIntegrationsGenerateConnectionUrl(
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -3927,9 +3909,9 @@ func (aH *APIHandler) CloudIntegrationsGetAccountStatus(
cloudProvider := mux.Vars(r)["cloudProvider"]
accountId := mux.Vars(r)["accountId"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -3955,9 +3937,9 @@ func (aH *APIHandler) CloudIntegrationsAgentCheckIn(
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -3985,9 +3967,9 @@ func (aH *APIHandler) CloudIntegrationsUpdateAccountConfig(
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -4009,9 +3991,9 @@ func (aH *APIHandler) CloudIntegrationsDisconnectAccount(
cloudProvider := mux.Vars(r)["cloudProvider"]
accountId := mux.Vars(r)["accountId"]
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -4039,9 +4021,9 @@ func (aH *APIHandler) CloudIntegrationsListServices(
cloudAccountId = &cloudAccountIdQP
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -4069,9 +4051,9 @@ func (aH *APIHandler) CloudIntegrationsGetServiceDetails(
cloudAccountId = &cloudAccountIdQP
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -4315,9 +4297,9 @@ func (aH *APIHandler) CloudIntegrationsUpdateServiceConfig(
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -4334,7 +4316,7 @@ func (aH *APIHandler) CloudIntegrationsUpdateServiceConfig(
}
// logs
func (aH *APIHandler) RegisterLogsRoutes(router *mux.Router, am *AuthMiddleware) {
func (aH *APIHandler) RegisterLogsRoutes(router *mux.Router, am *middleware.AuthZ) {
subRouter := router.PathPrefix("/api/v1/logs").Subrouter()
subRouter.HandleFunc("", am.ViewAccess(aH.getLogs)).Methods(http.MethodGet)
subRouter.HandleFunc("/tail", am.ViewAccess(aH.tailLogs)).Methods(http.MethodGet)
@ -4498,9 +4480,9 @@ func (aH *APIHandler) PreviewLogsPipelinesHandler(w http.ResponseWriter, r *http
}
func (aH *APIHandler) ListLogsPipelinesHandler(w http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -4577,10 +4559,9 @@ func (aH *APIHandler) listLogsPipelinesByVersion(ctx context.Context, orgID stri
}
func (aH *APIHandler) CreateLogsPipeline(w http.ResponseWriter, r *http.Request) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
@ -4622,11 +4603,12 @@ func (aH *APIHandler) getSavedViews(w http.ResponseWriter, r *http.Request) {
name := r.URL.Query().Get("name")
category := r.URL.Query().Get("category")
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
queries, err := explorer.GetViewsForFilters(r.Context(), claims.OrgID, sourcePage, name, category)
if err != nil {
RespondError(w, &model.ApiError{Typ: model.ErrorInternal, Err: err}, nil)
@ -4648,9 +4630,9 @@ func (aH *APIHandler) createSavedViews(w http.ResponseWriter, r *http.Request) {
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
uuid, err := explorer.CreateView(r.Context(), claims.OrgID, view)
@ -4670,9 +4652,9 @@ func (aH *APIHandler) getSavedView(w http.ResponseWriter, r *http.Request) {
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
view, err := explorer.GetView(r.Context(), claims.OrgID, viewUUID)
@ -4703,9 +4685,9 @@ func (aH *APIHandler) updateSavedView(w http.ResponseWriter, r *http.Request) {
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
err = explorer.UpdateView(r.Context(), claims.OrgID, viewUUID, view)
@ -4725,9 +4707,9 @@ func (aH *APIHandler) deleteSavedView(w http.ResponseWriter, r *http.Request) {
render.Error(w, errorsV2.Newf(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, err.Error()))
return
}
claims, ok := authtypes.ClaimsFromContext(r.Context())
if !ok {
render.Error(w, errorsV2.Newf(errorsV2.TypeUnauthenticated, errorsV2.CodeUnauthenticated, "unauthenticated"))
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 != nil {
render.Error(w, errv2)
return
}
err = explorer.DeleteView(r.Context(), claims.OrgID, viewUUID)
@ -5050,8 +5032,8 @@ func sendQueryResultEvents(r *http.Request, result []*v3.Result, queryRangeParam
if len(result) > 0 && (len(result[0].Series) > 0 || len(result[0].List) > 0) {
claims, ok := authtypes.ClaimsFromContext(r.Context())
if ok {
claims, errv2 := authtypes.ClaimsFromContext(r.Context())
if errv2 == nil {
queryInfoResult := telemetry.GetInstance().CheckQueryInfo(queryRangeParams)
if queryInfoResult.LogsUsed || queryInfoResult.MetricsUsed || queryInfoResult.TracesUsed {

View File

@ -6,14 +6,13 @@ import (
"testing"
"github.com/SigNoz/signoz/pkg/modules/organization"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/model"
v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3"
"github.com/SigNoz/signoz/pkg/query-service/utils"
"github.com/SigNoz/signoz/pkg/sqlstore"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/types/pipelinetypes"
ruletypes "github.com/SigNoz/signoz/pkg/types/ruletypes"
"github.com/google/uuid"
@ -42,13 +41,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
return nil, model.InternalError(err)
}
group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup)
if apiErr != nil {
return nil, model.InternalError(apiErr)
}
auth.InitAuthCache(ctx)
userId := uuid.NewString()
return dao.DB().CreateUser(
ctx,
@ -58,7 +50,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
Email: userId[:8] + "test@test.com",
Password: "test",
OrgID: organization.ID.StringValue(),
GroupID: group.ID,
Role: authtypes.RoleAdmin.String(),
},
true,
)

View File

@ -54,8 +54,8 @@ func (ic *LogParsingPipelineController) ApplyPipelines(
postable []pipelinetypes.PostablePipeline,
) (*PipelinesResponse, *model.ApiError) {
// get user id from context
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
return nil, model.UnauthorizedError(fmt.Errorf("failed to get userId from context"))
}

View File

@ -53,8 +53,8 @@ func (r *Repo) insertPipeline(
))
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
return nil, model.UnauthorizedError(fmt.Errorf("failed to get email from context"))
}

View File

@ -159,9 +159,9 @@ func (receiver *SummaryService) GetMetricsSummary(ctx context.Context, metricNam
g.Go(func() error {
var metricNames []string
metricNames = append(metricNames, metricName)
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return &model.ApiError{Typ: model.ErrorInternal, Err: errors.New("failed to get claims")}
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
return &model.ApiError{Typ: model.ErrorInternal, Err: errv2}
}
data, err := dashboards.GetDashboardsWithMetricNames(ctx, claims.OrgID, metricNames)
if err != nil {
@ -332,9 +332,9 @@ func (receiver *SummaryService) GetRelatedMetrics(ctx context.Context, params *m
alertsRelatedData := make(map[string][]metrics_explorer.Alert)
g.Go(func() error {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return &model.ApiError{Typ: model.ErrorInternal, Err: errors.New("failed to get claims")}
claims, errv2 := authtypes.ClaimsFromContext(ctx)
if errv2 != nil {
return &model.ApiError{Typ: model.ErrorInternal, Err: errv2}
}
names, apiError := dashboards.GetDashboardsWithMetricNames(ctx, claims.OrgID, metricNames)
if apiError != nil {

View File

@ -16,6 +16,7 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/app/integrations/messagingQueues/kafka"
queues2 "github.com/SigNoz/signoz/pkg/query-service/app/integrations/messagingQueues/queues"
"github.com/SigNoz/signoz/pkg/query-service/app/integrations/thirdPartyApi"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/govaluate"
"github.com/gorilla/mux"
@ -27,7 +28,6 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/app/queryBuilder"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/common"
"github.com/SigNoz/signoz/pkg/query-service/constants"
baseconstants "github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/model"
v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3"
@ -492,14 +492,6 @@ func parseInviteRequest(r *http.Request) (*model.InviteRequest, error) {
return &req, nil
}
func isValidRole(role string) bool {
switch role {
case constants.AdminGroup, constants.EditorGroup, constants.ViewerGroup:
return true
}
return false
}
func parseInviteUsersRequest(r *http.Request) (*model.BulkInviteRequest, error) {
var req model.BulkInviteRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@ -520,7 +512,9 @@ func parseInviteUsersRequest(r *http.Request) (*model.BulkInviteRequest, error)
if req.Users[i].FrontendBaseUrl == "" {
return nil, fmt.Errorf("frontendBaseUrl is required for each user")
}
if !isValidRole(req.Users[i].Role) {
_, err := authtypes.NewRole(req.Users[i].Role)
if err != nil {
return nil, fmt.Errorf("invalid role for user: %s", req.Users[i].Email)
}
}

View File

@ -2,7 +2,6 @@ package app
import (
"context"
"errors"
"fmt"
"net"
"net/http"
@ -30,7 +29,6 @@ import (
"github.com/SigNoz/signoz/pkg/signoz"
"github.com/SigNoz/signoz/pkg/sqlstore"
"github.com/SigNoz/signoz/pkg/telemetrystore"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/SigNoz/signoz/pkg/types/preferencetypes"
"github.com/SigNoz/signoz/pkg/web"
@ -38,7 +36,6 @@ import (
"github.com/soheilhy/cmux"
"github.com/SigNoz/signoz/pkg/query-service/app/explorer"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/cache"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
@ -308,21 +305,7 @@ func (s *Server) createPublicServer(api *APIHandler, web web.Web) (*http.Server,
r.Use(middleware.NewAnalytics(zap.L()).Wrap)
r.Use(middleware.NewLogging(zap.L(), s.serverOptions.Config.APIServer.Logging.ExcludedRoutes).Wrap)
// add auth middleware
getUserFromRequest := func(ctx context.Context) (*types.GettableUser, error) {
user, err := auth.GetUserFromReqContext(ctx)
if err != nil {
return nil, err
}
if user.User.OrgID == "" {
return nil, model.UnauthorizedError(errors.New("orgId is missing in the claims"))
}
return user, nil
}
am := NewAuthMiddleware(getUserFromRequest)
am := middleware.NewAuthZ(s.serverOptions.SigNoz.Instrumentation.Logger())
api.RegisterRoutes(r, am)
api.RegisterLogsRoutes(r, am)

View File

@ -26,18 +26,17 @@ import (
"golang.org/x/crypto/bcrypt"
)
type JwtContextKeyType string
const AccessJwtKey JwtContextKeyType = "accessJwt"
const RefreshJwtKey JwtContextKeyType = "refreshJwt"
const (
opaqueTokenSize = 16
minimumPasswordLength = 8
)
var (
ErrorInvalidCreds = fmt.Errorf("invalid credentials")
ErrorInvalidCreds = fmt.Errorf("invalid credentials")
ErrorEmptyRequest = errors.New("Empty request")
ErrorInvalidRole = errors.New("Invalid role")
ErrorInvalidInviteToken = errors.New("Invalid invite token")
ErrorAskAdmin = errors.New("An invitation is needed to create an account. Please ask your admin (the person who has first installed SIgNoz) to send an invite.")
)
type InviteEmailData struct {
@ -49,7 +48,10 @@ type InviteEmailData struct {
// The root user should be able to invite people to create account on SigNoz cluster.
func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteResponse, error) {
zap.L().Debug("Got an invite request for email", zap.String("email", req.Email))
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return nil, err
}
token, err := utils.RandomHex(opaqueTokenSize)
if err != nil {
@ -64,11 +66,6 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons
if user != nil {
return nil, errors.New("User already exists with the same email")
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errors.New("failed to extract OrgID from context")
}
// Check if an invite already exists
invite, apiErr := dao.DB().GetInviteFromEmail(ctx, req.Email)
if apiErr != nil {
@ -79,8 +76,9 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons
return nil, errors.New("An invite already exists for this email")
}
if err := validateInviteRequest(req); err != nil {
return nil, errors.Wrap(err, "invalid invite request")
role, err := authtypes.NewRole(req.Role)
if err != nil {
return nil, err
}
au, apiErr := dao.DB().GetUser(ctx, claims.UserID)
@ -99,7 +97,7 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons
Name: req.Name,
Email: req.Email,
Token: token,
Role: req.Role,
Role: role.String(),
OrgID: au.OrgID,
}
@ -120,6 +118,11 @@ func Invite(ctx context.Context, req *model.InviteRequest) (*model.InviteRespons
}
func InviteUsers(ctx context.Context, req *model.BulkInviteRequest) (*model.BulkInviteResponse, error) {
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return nil, err
}
response := &model.BulkInviteResponse{
Status: "success",
Summary: model.InviteSummary{TotalInvites: len(req.Users)},
@ -127,11 +130,6 @@ func InviteUsers(ctx context.Context, req *model.BulkInviteRequest) (*model.Bulk
FailedInvites: []model.FailedInvite{},
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errors.New("failed to extract admin user id")
}
au, apiErr := dao.DB().GetUser(ctx, claims.UserID)
if apiErr != nil {
return nil, errors.Wrap(apiErr.Err, "failed to query admin user from the DB")
@ -191,8 +189,9 @@ func inviteUser(ctx context.Context, req *model.InviteRequest, au *types.Gettabl
return nil, errors.New("An invite already exists for this email")
}
if err := validateInviteRequest(req); err != nil {
return nil, errors.Wrap(err, "invalid invite request")
role, err := authtypes.NewRole(req.Role)
if err != nil {
return nil, err
}
inv := &types.Invite{
@ -206,7 +205,7 @@ func inviteUser(ctx context.Context, req *model.InviteRequest, au *types.Gettabl
Name: req.Name,
Email: req.Email,
Token: token,
Role: req.Role,
Role: role.String(),
OrgID: au.OrgID,
}
@ -260,15 +259,9 @@ func inviteEmail(req *model.InviteRequest, au *types.GettableUser, token string)
// RevokeInvite is used to revoke the invitation for the given email.
func RevokeInvite(ctx context.Context, email string) error {
zap.L().Debug("RevokeInvite method invoked for email", zap.String("email", email))
if !isValidEmail(email) {
return ErrorInvalidInviteToken
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return errors.New("failed to org id from context")
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return err
}
if err := dao.DB().DeleteInvitation(ctx, claims.OrgID, email); err != nil {
@ -414,19 +407,12 @@ func RegisterFirstUser(ctx context.Context, req *RegisterRequest, organizationMo
return nil, model.BadRequest(model.ErrPasswordRequired{})
}
groupName := constants.AdminGroup
organization := types.NewOrganization(req.OrgDisplayName)
err := organizationModule.Create(ctx, organization)
if err != nil {
return nil, model.InternalError(err)
}
group, apiErr := dao.DB().GetGroupByName(ctx, groupName)
if apiErr != nil {
zap.L().Error("GetGroupByName failed", zap.Error(apiErr.Err))
return nil, apiErr
}
var hash string
hash, err = PasswordHash(req.Password)
if err != nil {
@ -443,7 +429,7 @@ func RegisterFirstUser(ctx context.Context, req *RegisterRequest, organizationMo
CreatedAt: time.Now(),
},
ProfilePictureURL: "", // Currently unused
GroupID: group.ID,
Role: authtypes.RoleAdmin.String(),
OrgID: organization.ID.StringValue(),
}
@ -488,13 +474,7 @@ func RegisterInvitedUser(ctx context.Context, req *RegisterRequest, nopassword b
if invite.Role == "" {
// if role is not provided, default to viewer
invite.Role = constants.ViewerGroup
}
group, apiErr := dao.DB().GetGroupByName(ctx, invite.Role)
if apiErr != nil {
zap.L().Error("GetGroupByName failed", zap.Error(apiErr.Err))
return nil, model.InternalError(model.ErrSignupFailed{})
invite.Role = authtypes.RoleViewer.String()
}
var hash string
@ -523,12 +503,12 @@ func RegisterInvitedUser(ctx context.Context, req *RegisterRequest, nopassword b
CreatedAt: time.Now(),
},
ProfilePictureURL: "", // Currently unused
GroupID: group.ID,
Role: invite.Role,
OrgID: invite.OrgID,
}
// TODO(Ahsan): Ideally create user and delete invitation should happen in a txn.
user, apiErr = dao.DB().CreateUser(ctx, user, false)
user, apiErr := dao.DB().CreateUser(ctx, user, false)
if apiErr != nil {
zap.L().Error("CreateUser failed", zap.Error(apiErr.Err))
return nil, apiErr
@ -574,8 +554,6 @@ func Register(ctx context.Context, req *RegisterRequest, alertmanager alertmanag
// Login method returns access and refresh tokens on successful login, else it errors out.
func Login(ctx context.Context, request *model.LoginRequest, jwt *authtypes.JWT) (*model.LoginResponse, error) {
zap.L().Debug("Login method called for user", zap.String("email", request.Email))
user, err := authenticateLogin(ctx, request, jwt)
if err != nil {
zap.L().Error("Failed to authenticate login request", zap.Error(err))
@ -599,18 +577,6 @@ func Login(ctx context.Context, request *model.LoginRequest, jwt *authtypes.JWT)
}, nil
}
func claimsToUserPayload(claims authtypes.Claims) (*types.GettableUser, error) {
user := &types.GettableUser{
User: types.User{
ID: claims.UserID,
GroupID: claims.GroupID,
Email: claims.Email,
OrgID: claims.OrgID,
},
}
return user, nil
}
// authenticateLogin is responsible for querying the DB and validating the credentials.
func authenticateLogin(ctx context.Context, req *model.LoginRequest, jwt *authtypes.JWT) (*types.GettableUser, error) {
// If refresh token is valid, then simply authorize the login request.
@ -621,13 +587,13 @@ func authenticateLogin(ctx context.Context, req *model.LoginRequest, jwt *authty
return nil, errors.Wrap(err, "failed to parse refresh token")
}
if claims.OrgID == "" {
return nil, model.UnauthorizedError(errors.New("orgId is missing in the claims"))
}
user, err := claimsToUserPayload(claims)
if err != nil {
return nil, errors.Wrap(err, "failed to convert claims to user payload")
user := &types.GettableUser{
User: types.User{
ID: claims.UserID,
Role: claims.Role.String(),
Email: claims.Email,
OrgID: claims.OrgID,
},
}
return user, nil
}
@ -658,18 +624,32 @@ func passwordMatch(hash, password string) bool {
}
func GenerateJWTForUser(user *types.User, jwt *authtypes.JWT) (model.UserJwtObject, error) {
j := model.UserJwtObject{}
var err error
j.AccessJwtExpiry = time.Now().Add(jwt.JwtExpiry).Unix()
j.AccessJwt, err = jwt.AccessToken(user.OrgID, user.ID, user.GroupID, user.Email)
role, err := authtypes.NewRole(user.Role)
if err != nil {
return j, errors.Errorf("failed to encode jwt: %v", err)
return model.UserJwtObject{}, err
}
j.RefreshJwtExpiry = time.Now().Add(jwt.JwtRefresh).Unix()
j.RefreshJwt, err = jwt.RefreshToken(user.OrgID, user.ID, user.GroupID, user.Email)
accessJwt, accessClaims, err := jwt.AccessToken(user.OrgID, user.ID, user.Email, role)
if err != nil {
return j, errors.Errorf("failed to encode jwt: %v", err)
return model.UserJwtObject{}, err
}
return j, nil
refreshJwt, refreshClaims, err := jwt.RefreshToken(user.OrgID, user.ID, user.Email, role)
if err != nil {
return model.UserJwtObject{}, err
}
return model.UserJwtObject{
AccessJwt: accessJwt,
RefreshJwt: refreshJwt,
AccessJwtExpiry: accessClaims.ExpiresAt.Unix(),
RefreshJwtExpiry: refreshClaims.ExpiresAt.Unix(),
}, nil
}
func ValidatePassword(password string) error {
if len(password) < minimumPasswordLength {
return errors.Errorf("Password should be atleast %d characters.", minimumPasswordLength)
}
return nil
}

View File

@ -1,82 +0,0 @@
package auth
import (
"context"
errorsV2 "github.com/SigNoz/signoz/pkg/errors"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/pkg/errors"
)
type Group struct {
GroupID string
GroupName string
}
type AuthCache struct {
AdminGroupId string
EditorGroupId string
ViewerGroupId string
}
var AuthCacheObj AuthCache
// InitAuthCache reads the DB and initialize the auth cache.
func InitAuthCache(ctx context.Context) error {
setGroupId := func(groupName string, dest *string) error {
group, err := dao.DB().GetGroupByName(ctx, groupName)
if err != nil {
return errors.Wrapf(err.Err, "failed to get group %s", groupName)
}
*dest = group.ID
return nil
}
if err := setGroupId(constants.AdminGroup, &AuthCacheObj.AdminGroupId); err != nil {
return err
}
if err := setGroupId(constants.EditorGroup, &AuthCacheObj.EditorGroupId); err != nil {
return err
}
if err := setGroupId(constants.ViewerGroup, &AuthCacheObj.ViewerGroupId); err != nil {
return err
}
return nil
}
func GetUserFromReqContext(ctx context.Context) (*types.GettableUser, error) {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errorsV2.New(errorsV2.TypeInvalidInput, errorsV2.CodeInvalidInput, "no claims found in context")
}
user := &types.GettableUser{
User: types.User{
ID: claims.UserID,
GroupID: claims.GroupID,
Email: claims.Email,
OrgID: claims.OrgID,
},
}
return user, nil
}
func IsSelfAccessRequest(user *types.GettableUser, id string) bool { return user.ID == id }
func IsViewer(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.ViewerGroupId }
func IsEditor(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.EditorGroupId }
func IsAdmin(user *types.GettableUser) bool { return user.GroupID == AuthCacheObj.AdminGroupId }
func IsAdminV2(claims authtypes.Claims) bool { return claims.GroupID == AuthCacheObj.AdminGroupId }
func ValidatePassword(password string) error {
if len(password) < minimumPasswordLength {
return errors.Errorf("Password should be atleast %d characters.", minimumPasswordLength)
}
return nil
}

View File

@ -1,43 +0,0 @@
package auth
import (
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/pkg/errors"
)
var (
ErrorEmptyRequest = errors.New("Empty request")
ErrorInvalidEmail = errors.New("Invalid email")
ErrorInvalidRole = errors.New("Invalid role")
ErrorInvalidInviteToken = errors.New("Invalid invite token")
ErrorAskAdmin = errors.New("An invitation is needed to create an account. Please ask your admin (the person who has first installed SIgNoz) to send an invite.")
)
func isValidRole(role string) bool {
switch role {
case constants.AdminGroup, constants.EditorGroup, constants.ViewerGroup:
return true
}
return false
}
func validateInviteRequest(req *model.InviteRequest) error {
if req == nil {
return ErrorEmptyRequest
}
if !isValidEmail(req.Email) {
return ErrorInvalidEmail
}
if !isValidRole(req.Role) {
return ErrorInvalidRole
}
return nil
}
// TODO(Ahsan): Implement check on email semantic.
func isValidEmail(email string) bool {
return true
}

View File

@ -1,7 +0,0 @@
package constants
const (
AdminGroup = "ADMIN"
EditorGroup = "EDITOR"
ViewerGroup = "VIEWER"
)

View File

@ -5,6 +5,7 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/types"
"github.com/SigNoz/signoz/pkg/types/authtypes"
)
type ModelDao interface {
@ -22,13 +23,9 @@ type Queries interface {
GetUsers(ctx context.Context) ([]types.GettableUser, *model.ApiError)
GetUsersWithOpts(ctx context.Context, limit int) ([]types.GettableUser, *model.ApiError)
GetGroup(ctx context.Context, id string) (*types.Group, *model.ApiError)
GetGroupByName(ctx context.Context, name string) (*types.Group, *model.ApiError)
GetGroups(ctx context.Context) ([]types.Group, *model.ApiError)
GetResetPasswordEntry(ctx context.Context, token string) (*types.ResetPasswordRequest, *model.ApiError)
GetUsersByOrg(ctx context.Context, orgId string) ([]types.GettableUser, *model.ApiError)
GetUsersByGroup(ctx context.Context, groupId string) ([]types.GettableUser, *model.ApiError)
GetUsersByRole(ctx context.Context, role authtypes.Role) ([]types.GettableUser, *model.ApiError)
GetApdexSettings(ctx context.Context, orgID string, services []string) ([]types.ApdexSettings, *model.ApiError)
@ -43,14 +40,11 @@ type Mutations interface {
EditUser(ctx context.Context, update *types.User) (*types.User, *model.ApiError)
DeleteUser(ctx context.Context, id string) *model.ApiError
CreateGroup(ctx context.Context, group *types.Group) (*types.Group, *model.ApiError)
DeleteGroup(ctx context.Context, id string) *model.ApiError
CreateResetPasswordEntry(ctx context.Context, req *types.ResetPasswordRequest) *model.ApiError
DeleteResetPasswordEntry(ctx context.Context, token string) *model.ApiError
UpdateUserPassword(ctx context.Context, hash, userId string) *model.ApiError
UpdateUserGroup(ctx context.Context, userId, groupId string) *model.ApiError
UpdateUserRole(ctx context.Context, userId string, role authtypes.Role) *model.ApiError
SetApdexSettings(ctx context.Context, orgID string, set *types.ApdexSettings) *model.ApiError
}

View File

@ -9,7 +9,6 @@ import (
"github.com/SigNoz/signoz/pkg/types"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"go.uber.org/zap"
)
type ModelDaoSqlite struct {
@ -24,12 +23,8 @@ func InitDB(sqlStore sqlstore.SQLStore) (*ModelDaoSqlite, error) {
if err := mds.initializeOrgPreferences(ctx); err != nil {
return nil, err
}
if err := mds.initializeRBAC(ctx); err != nil {
return nil, err
}
telemetry.GetInstance().SetUserCountCallback(mds.GetUserCount)
telemetry.GetInstance().SetUserRoleCallback(mds.GetUserRole)
telemetry.GetInstance().SetGetUsersCallback(mds.GetUsers)
return mds, nil
@ -76,42 +71,3 @@ func (mds *ModelDaoSqlite) initializeOrgPreferences(ctx context.Context) error {
return nil
}
// initializeRBAC creates the ADMIN, EDITOR and VIEWER groups if they are not present.
func (mds *ModelDaoSqlite) initializeRBAC(ctx context.Context) error {
f := func(groupName string) error {
_, err := mds.createGroupIfNotPresent(ctx, groupName)
return errors.Wrap(err, "Failed to create group")
}
if err := f(constants.AdminGroup); err != nil {
return err
}
if err := f(constants.EditorGroup); err != nil {
return err
}
if err := f(constants.ViewerGroup); err != nil {
return err
}
return nil
}
func (mds *ModelDaoSqlite) createGroupIfNotPresent(ctx context.Context,
name string) (*types.Group, error) {
group, err := mds.GetGroupByName(ctx, name)
if err != nil {
return nil, errors.Wrap(err.Err, "Failed to query for root group")
}
if group != nil {
return group, nil
}
zap.L().Debug("group is not found, creating it", zap.String("group_name", name))
group, cErr := mds.CreateGroup(ctx, &types.Group{Name: name})
if cErr != nil {
return nil, cErr.Err
}
return group, nil
}

View File

@ -7,7 +7,7 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/query-service/telemetry"
"github.com/SigNoz/signoz/pkg/types"
"github.com/google/uuid"
"github.com/SigNoz/signoz/pkg/types/authtypes"
"github.com/pkg/errors"
)
@ -163,11 +163,11 @@ func (mds *ModelDaoSqlite) UpdateUserPassword(ctx context.Context, passwordHash,
return nil
}
func (mds *ModelDaoSqlite) UpdateUserGroup(ctx context.Context, userId, groupId string) *model.ApiError {
func (mds *ModelDaoSqlite) UpdateUserRole(ctx context.Context, userId string, role authtypes.Role) *model.ApiError {
_, err := mds.bundb.NewUpdate().
Model(&types.User{}).
Set("group_id = ?", groupId).
Set("role = ?", role).
Where("id = ?", userId).
Exec(ctx)
@ -207,10 +207,8 @@ func (mds *ModelDaoSqlite) GetUser(ctx context.Context,
users := []types.GettableUser{}
query := mds.bundb.NewSelect().
Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id").
ColumnExpr("g.name as role").
ColumnExpr("o.display_name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("o.name as organization").
Join("JOIN organizations o ON o.id = users.org_id").
Where("users.id = ?", id)
@ -244,10 +242,8 @@ func (mds *ModelDaoSqlite) GetUserByEmail(ctx context.Context,
users := []types.GettableUser{}
query := mds.bundb.NewSelect().
Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id").
ColumnExpr("g.name as role").
ColumnExpr("o.display_name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("o.name as organization").
Join("JOIN organizations o ON o.id = users.org_id").
Where("users.email = ?", email)
@ -279,10 +275,9 @@ func (mds *ModelDaoSqlite) GetUsersWithOpts(ctx context.Context, limit int) ([]t
query := mds.bundb.NewSelect().
Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id").
ColumnExpr("g.name as role").
ColumnExpr("o.display_name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("users.role as role").
ColumnExpr("o.name as organization").
Join("JOIN organizations o ON o.id = users.org_id")
if limit > 0 {
@ -303,10 +298,9 @@ func (mds *ModelDaoSqlite) GetUsersByOrg(ctx context.Context,
query := mds.bundb.NewSelect().
Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id").
ColumnExpr("g.name as role").
ColumnExpr("o.display_name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("users.role as role").
ColumnExpr("o.name as organization").
Join("JOIN organizations o ON o.id = users.org_id").
Where("users.org_id = ?", orgId)
@ -317,19 +311,16 @@ func (mds *ModelDaoSqlite) GetUsersByOrg(ctx context.Context,
return users, nil
}
func (mds *ModelDaoSqlite) GetUsersByGroup(ctx context.Context,
groupId string) ([]types.GettableUser, *model.ApiError) {
func (mds *ModelDaoSqlite) GetUsersByRole(ctx context.Context, role authtypes.Role) ([]types.GettableUser, *model.ApiError) {
users := []types.GettableUser{}
query := mds.bundb.NewSelect().
Table("users").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.group_id").
ColumnExpr("g.name as role").
ColumnExpr("o.display_name as organization").
Join("JOIN groups g ON g.id = users.group_id").
Column("users.id", "users.name", "users.email", "users.password", "users.created_at", "users.profile_picture_url", "users.org_id", "users.role").
ColumnExpr("users.role as role").
ColumnExpr("o.name as organization").
Join("JOIN organizations o ON o.id = users.org_id").
Where("users.group_id = ?", groupId)
Where("users.role = ?", role)
err := query.Scan(ctx, &users)
if err != nil {
@ -338,98 +329,7 @@ func (mds *ModelDaoSqlite) GetUsersByGroup(ctx context.Context,
return users, nil
}
func (mds *ModelDaoSqlite) CreateGroup(ctx context.Context,
group *types.Group) (*types.Group, *model.ApiError) {
group.ID = uuid.NewString()
if _, err := mds.bundb.NewInsert().
Model(group).
Exec(ctx); err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
return group, nil
}
func (mds *ModelDaoSqlite) DeleteGroup(ctx context.Context, id string) *model.ApiError {
_, err := mds.bundb.NewDelete().
Model(&types.Group{}).
Where("id = ?", id).
Exec(ctx)
if err != nil {
return &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
return nil
}
func (mds *ModelDaoSqlite) GetGroup(ctx context.Context,
id string) (*types.Group, *model.ApiError) {
groups := []types.Group{}
if err := mds.bundb.NewSelect().
Model(&groups).
Where("id = ?", id).
Scan(ctx); err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
if len(groups) > 1 {
return nil, &model.ApiError{
Typ: model.ErrorInternal,
Err: errors.New("Found multiple groups with same ID."),
}
}
if len(groups) == 0 {
return nil, nil
}
return &groups[0], nil
}
func (mds *ModelDaoSqlite) GetGroupByName(ctx context.Context,
name string) (*types.Group, *model.ApiError) {
groups := []types.Group{}
err := mds.bundb.NewSelect().
Model(&groups).
Where("name = ?", name).
Scan(ctx)
if err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
if len(groups) > 1 {
return nil, &model.ApiError{
Typ: model.ErrorInternal,
Err: errors.New("Found multiple groups with same name"),
}
}
if len(groups) == 0 {
return nil, nil
}
return &groups[0], nil
}
// TODO(nitya): should have org id
func (mds *ModelDaoSqlite) GetGroups(ctx context.Context) ([]types.Group, *model.ApiError) {
groups := []types.Group{}
if err := mds.bundb.NewSelect().
Model(&groups).
Scan(ctx); err != nil {
return nil, &model.ApiError{Typ: model.ErrorInternal, Err: err}
}
return groups, nil
}
func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context,
req *types.ResetPasswordRequest) *model.ApiError {
func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context, req *types.ResetPasswordRequest) *model.ApiError {
if _, err := mds.bundb.NewInsert().
Model(req).
@ -439,8 +339,7 @@ func (mds *ModelDaoSqlite) CreateResetPasswordEntry(ctx context.Context,
return nil
}
func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context,
token string) *model.ApiError {
func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context, token string) *model.ApiError {
_, err := mds.bundb.NewDelete().
Model(&types.ResetPasswordRequest{}).
Where("token = ?", token).
@ -452,8 +351,7 @@ func (mds *ModelDaoSqlite) DeleteResetPasswordEntry(ctx context.Context,
return nil
}
func (mds *ModelDaoSqlite) GetResetPasswordEntry(ctx context.Context,
token string) (*types.ResetPasswordRequest, *model.ApiError) {
func (mds *ModelDaoSqlite) GetResetPasswordEntry(ctx context.Context, token string) (*types.ResetPasswordRequest, *model.ApiError) {
entries := []types.ResetPasswordRequest{}
@ -491,14 +389,6 @@ func (mds *ModelDaoSqlite) PrecheckLogin(ctx context.Context, email, sourceUrl s
return resp, nil
}
func (mds *ModelDaoSqlite) GetUserRole(ctx context.Context, groupId string) (string, error) {
role, err := mds.GetGroup(ctx, groupId)
if err != nil || role == nil {
return "", err
}
return role.Name, nil
}
func (mds *ModelDaoSqlite) GetUserCount(ctx context.Context) (int, error) {
users, err := mds.GetUsers(ctx)
if err != nil {

View File

@ -10,7 +10,6 @@ import (
"github.com/SigNoz/signoz/pkg/config/envprovider"
"github.com/SigNoz/signoz/pkg/config/fileprovider"
"github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/signoz"
"github.com/SigNoz/signoz/pkg/types/authtypes"
@ -139,10 +138,6 @@ func main() {
logger.Fatal("Could not start servers", zap.Error(err))
}
if err := auth.InitAuthCache(context.Background()); err != nil {
logger.Fatal("Failed to initialize auth cache", zap.Error(err))
}
signoz.Start(context.Background())
if err := signoz.Wait(context.Background()); err != nil {

View File

@ -332,9 +332,9 @@ func (m *Manager) Stop(ctx context.Context) {
// EditRuleDefinition writes the rule definition to the
// datastore and also updates the rule executor
func (m *Manager) EditRule(ctx context.Context, ruleStr string, idStr string) error {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return errors.New("claims not found in context")
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return err
}
ruleUUID, err := valuer.NewUUID(idStr)
@ -469,9 +469,9 @@ func (m *Manager) DeleteRule(ctx context.Context, idStr string) error {
return fmt.Errorf("delete rule received an rule id in invalid format, must be a valid uuid-v7")
}
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return errors.New("claims not found in context")
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return err
}
_, err = m.ruleStore.GetStoredRule(ctx, id)
@ -523,9 +523,9 @@ func (m *Manager) deleteTask(taskName string) {
// CreateRule stores rule def into db and also
// starts an executor for the rule
func (m *Manager) CreateRule(ctx context.Context, ruleStr string) (*ruletypes.GettableRule, error) {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errors.New("claims not found in context")
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return nil, err
}
parsedRule, err := ruletypes.ParsePostableRule([]byte(ruleStr))
@ -804,9 +804,9 @@ func (m *Manager) ListActiveRules() ([]Rule, error) {
}
func (m *Manager) ListRuleStates(ctx context.Context) (*ruletypes.GettableRules, error) {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errors.New("claims not found in context")
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return nil, err
}
// fetch rules from DB
storedRules, err := m.ruleStore.GetStoredRules(ctx, claims.OrgID)
@ -918,9 +918,9 @@ func (m *Manager) syncRuleStateWithTask(ctx context.Context, orgID string, taskN
// - re-deploy or undeploy task as necessary
// - update the patched rule in the DB
func (m *Manager) PatchRule(ctx context.Context, ruleStr string, ruleIdStr string) (*ruletypes.GettableRule, error) {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, errors.New("claims not found in context")
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return nil, err
}
ruleID, err := valuer.NewUUID(ruleIdStr)
@ -1020,9 +1020,9 @@ func (m *Manager) TestNotification(ctx context.Context, ruleStr string) (int, *m
}
func (m *Manager) GetAlertDetailsForMetricNames(ctx context.Context, metricNames []string) (map[string][]ruletypes.GettableRule, *model.ApiError) {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return nil, &model.ApiError{Typ: model.ErrorExec, Err: errors.New("claims not found in context")}
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return nil, &model.ApiError{Typ: model.ErrorExec, Err: err}
}
result := make(map[string][]ruletypes.GettableRule)

View File

@ -206,7 +206,6 @@ type Telemetry struct {
alertsInfoCallback func(ctx context.Context) (*model.AlertsInfo, error)
userCountCallback func(ctx context.Context) (int, error)
userRoleCallback func(ctx context.Context, groupId string) (string, error)
getUsersCallback func(ctx context.Context) ([]types.GettableUser, *model.ApiError)
dashboardsInfoCallback func(ctx context.Context) (*model.DashboardsInfo, error)
savedViewsInfoCallback func(ctx context.Context) (*model.SavedViewsInfo, error)
@ -220,10 +219,6 @@ func (a *Telemetry) SetUserCountCallback(callback func(ctx context.Context) (int
a.userCountCallback = callback
}
func (a *Telemetry) SetUserRoleCallback(callback func(ctx context.Context, groupId string) (string, error)) {
a.userRoleCallback = callback
}
func (a *Telemetry) SetGetUsersCallback(callback func(ctx context.Context) ([]types.GettableUser, *model.ApiError)) {
a.getUsersCallback = callback
}
@ -555,21 +550,12 @@ func (a *Telemetry) IdentifyUser(user *types.User) {
if !a.isTelemetryEnabled() || a.isTelemetryAnonymous() {
return
}
// extract user group from user.groupId
role, _ := a.userRoleCallback(context.Background(), user.GroupID)
if a.saasOperator != nil {
if role != "" {
_ = a.saasOperator.Enqueue(analytics.Identify{
UserId: a.userEmail,
Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email).Set("role", role),
})
} else {
_ = a.saasOperator.Enqueue(analytics.Identify{
UserId: a.userEmail,
Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email),
})
}
_ = a.saasOperator.Enqueue(analytics.Identify{
UserId: a.userEmail,
Traits: analytics.NewTraits().SetName(user.Name).SetEmail(user.Email).Set("role", user.Role),
})
_ = a.saasOperator.Enqueue(analytics.Group{
UserId: a.userEmail,

View File

@ -10,9 +10,9 @@ import (
"testing"
"github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/featureManager"
@ -310,7 +310,7 @@ func NewFilterSuggestionsTestBed(t *testing.T) *FilterSuggestionsTestBed {
router := app.NewRouter()
//add the jwt middleware
router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
am := app.NewAuthMiddleware(auth.GetUserFromReqContext)
am := middleware.NewAuthZ(instrumentationtest.New().Logger())
apiHandler.RegisterRoutes(router, am)
apiHandler.RegisterQueryRangeV3Routes(router, am)

View File

@ -11,9 +11,9 @@ import (
"github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest"
"github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/app/cloudintegrations"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/featureManager"
"github.com/SigNoz/signoz/pkg/query-service/utils"
@ -373,7 +373,7 @@ func NewCloudIntegrationsTestBed(t *testing.T, testDB sqlstore.SQLStore) *CloudI
router := app.NewRouter()
router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
am := app.NewAuthMiddleware(auth.GetUserFromReqContext)
am := middleware.NewAuthZ(instrumentationtest.New().Logger())
apiHandler.RegisterRoutes(router, am)
apiHandler.RegisterCloudIntegrationsRoutes(router, am)

View File

@ -9,11 +9,11 @@ import (
"time"
"github.com/SigNoz/signoz/pkg/http/middleware"
"github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest"
"github.com/SigNoz/signoz/pkg/modules/organization/implorganization"
"github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/app/cloudintegrations"
"github.com/SigNoz/signoz/pkg/query-service/app/integrations"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/featureManager"
"github.com/SigNoz/signoz/pkg/query-service/model"
@ -580,7 +580,7 @@ func NewIntegrationsTestBed(t *testing.T, testDB sqlstore.SQLStore) *Integration
router := app.NewRouter()
router.Use(middleware.NewAuth(zap.L(), jwt, []string{"Authorization", "Sec-WebSocket-Protocol"}).Wrap)
am := app.NewAuthMiddleware(auth.GetUserFromReqContext)
am := middleware.NewAuthZ(instrumentationtest.New().Logger())
apiHandler.RegisterRoutes(router, am)
apiHandler.RegisterIntegrationRoutes(router, am)

View File

@ -20,7 +20,6 @@ import (
"github.com/SigNoz/signoz/pkg/query-service/app"
"github.com/SigNoz/signoz/pkg/query-service/app/clickhouseReader"
"github.com/SigNoz/signoz/pkg/query-service/auth"
"github.com/SigNoz/signoz/pkg/query-service/constants"
"github.com/SigNoz/signoz/pkg/query-service/dao"
"github.com/SigNoz/signoz/pkg/query-service/model"
"github.com/SigNoz/signoz/pkg/sqlstore"
@ -158,13 +157,6 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
return nil, model.InternalError(err)
}
group, apiErr := dao.DB().GetGroupByName(ctx, constants.AdminGroup)
if apiErr != nil {
return nil, model.InternalError(apiErr)
}
auth.InitAuthCache(ctx)
userId := uuid.NewString()
return dao.DB().CreateUser(
@ -175,7 +167,7 @@ func createTestUser(organizationModule organization.Module) (*types.User, *model
Email: userId[:8] + "test@test.com",
Password: "test",
OrgID: organization.ID.StringValue(),
GroupID: group.ID,
Role: authtypes.RoleAdmin.String(),
},
true,
)

View File

@ -58,6 +58,7 @@ func NewTestSqliteDB(t *testing.T) (sqlStore sqlstore.SQLStore, testDBFilePath s
sqlmigration.NewAddVirtualFieldsFactory(),
sqlmigration.NewUpdateIntegrationsFactory(sqlStore),
sqlmigration.NewUpdateOrganizationsFactory(sqlStore),
sqlmigration.NewDropGroupsFactory(sqlStore),
),
)
if err != nil {

View File

@ -2,7 +2,6 @@ package sqlrulestore
import (
"context"
"errors"
"time"
"github.com/SigNoz/signoz/pkg/sqlstore"
@ -60,10 +59,9 @@ func (r *maintenance) GetPlannedMaintenanceByID(ctx context.Context, id valuer.U
}
func (r *maintenance) CreatePlannedMaintenance(ctx context.Context, maintenance ruletypes.GettablePlannedMaintenance) (valuer.UUID, error) {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return valuer.UUID{}, errors.New("no claims found in context")
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return valuer.UUID{}, err
}
storablePlannedMaintenance := ruletypes.StorablePlannedMaintenance{
@ -100,7 +98,7 @@ func (r *maintenance) CreatePlannedMaintenance(ctx context.Context, maintenance
})
}
err := r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error {
err = r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error {
_, err := r.sqlstore.
BunDBCtx(ctx).
NewInsert().
@ -147,9 +145,9 @@ func (r *maintenance) DeletePlannedMaintenance(ctx context.Context, id valuer.UU
}
func (r *maintenance) EditPlannedMaintenance(ctx context.Context, maintenance ruletypes.GettablePlannedMaintenance, id valuer.UUID) error {
claims, ok := authtypes.ClaimsFromContext(ctx)
if !ok {
return errors.New("no claims found in context")
claims, err := authtypes.ClaimsFromContext(ctx)
if err != nil {
return err
}
storablePlannedMaintenance := ruletypes.StorablePlannedMaintenance{
@ -186,7 +184,7 @@ func (r *maintenance) EditPlannedMaintenance(ctx context.Context, maintenance ru
})
}
err := r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error {
err = r.sqlstore.RunInTxCtx(ctx, nil, func(ctx context.Context) error {
_, err := r.sqlstore.
BunDBCtx(ctx).
NewUpdate().

View File

@ -73,6 +73,7 @@ func NewSQLMigrationProviderFactories(sqlstore sqlstore.SQLStore) factory.NamedM
sqlmigration.NewAddVirtualFieldsFactory(),
sqlmigration.NewUpdateIntegrationsFactory(sqlstore),
sqlmigration.NewUpdateOrganizationsFactory(sqlstore),
sqlmigration.NewDropGroupsFactory(sqlstore),
)
}

View File

@ -19,12 +19,13 @@ import (
type SigNoz struct {
*factory.Registry
Cache cache.Cache
Web web.Web
SQLStore sqlstore.SQLStore
TelemetryStore telemetrystore.TelemetryStore
Prometheus prometheus.Prometheus
Alertmanager alertmanager.Alertmanager
Instrumentation instrumentation.Instrumentation
Cache cache.Cache
Web web.Web
SQLStore sqlstore.SQLStore
TelemetryStore telemetrystore.TelemetryStore
Prometheus prometheus.Prometheus
Alertmanager alertmanager.Alertmanager
}
func New(
@ -144,12 +145,13 @@ func New(
}
return &SigNoz{
Registry: registry,
Cache: cache,
Web: web,
SQLStore: sqlstore,
TelemetryStore: telemetrystore,
Prometheus: prometheus,
Alertmanager: alertmanager,
Registry: registry,
Instrumentation: instrumentation,
Cache: cache,
Web: web,
SQLStore: sqlstore,
TelemetryStore: telemetrystore,
Prometheus: prometheus,
Alertmanager: alertmanager,
}, nil
}

View File

@ -87,14 +87,14 @@ func (migration *updateOrganization) Up(ctx context.Context, db *bun.DB) error {
// since organizations, users has created_at as integer instead of timestamp
for _, table := range []string{"organizations", "users", "invites"} {
if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "created_at"); err != nil {
if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "created_at"); err != nil {
return err
}
}
// migrate is_anonymous and has_opted_updates to boolean from int
for _, column := range []string{"is_anonymous", "has_opted_updates"} {
if err := migration.store.Dialect().MigrateIntToBoolean(ctx, tx, "organizations", column); err != nil {
if err := migration.store.Dialect().IntToBoolean(ctx, tx, "organizations", column); err != nil {
return err
}
}

View File

@ -66,16 +66,16 @@ func (migration *updatePatAndOrgDomains) Up(ctx context.Context, db *bun.DB) err
}
}
if err := updateOrgId(ctx, tx, "org_domains"); err != nil {
if err := updateOrgId(ctx, tx); err != nil {
return err
}
// change created_at and updated_at from integer to timestamp
for _, table := range []string{"personal_access_tokens", "org_domains"} {
if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "created_at"); err != nil {
if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "created_at"); err != nil {
return err
}
if err := migration.store.Dialect().MigrateIntToTimestamp(ctx, tx, table, "updated_at"); err != nil {
if err := migration.store.Dialect().IntToTimestamp(ctx, tx, table, "updated_at"); err != nil {
return err
}
}
@ -96,7 +96,7 @@ func (migration *updatePatAndOrgDomains) Down(ctx context.Context, db *bun.DB) e
return nil
}
func updateOrgId(ctx context.Context, tx bun.Tx, table string) error {
func updateOrgId(ctx context.Context, tx bun.Tx) error {
if _, err := tx.NewCreateTable().
Model(&struct {
bun.BaseModel `bun:"table:org_domains_new"`

View File

@ -359,7 +359,20 @@ func (migration *updateIntegrations) CopyOldCloudIntegrationServicesToNewCloudIn
}
func (migration *updateIntegrations) copyOldAwsIntegrationUser(tx bun.IDB, orgID string) error {
user := &types.User{}
type oldUser struct {
bun.BaseModel `bun:"table:users"`
types.TimeAuditable
ID string `bun:"id,pk,type:text" json:"id"`
Name string `bun:"name,type:text,notnull" json:"name"`
Email string `bun:"email,type:text,notnull,unique" json:"email"`
Password string `bun:"password,type:text,notnull" json:"-"`
ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"`
GroupID string `bun:"group_id,type:text,notnull" json:"groupId"`
OrgID string `bun:"org_id,type:text,notnull" json:"orgId"`
}
user := &oldUser{}
err := tx.NewSelect().Model(user).Where("email = ?", "aws-integration@signoz.io").Scan(context.Background())
if err != nil {
if err == sql.ErrNoRows {
@ -374,7 +387,7 @@ func (migration *updateIntegrations) copyOldAwsIntegrationUser(tx bun.IDB, orgID
}
// new user
newUser := &types.User{
newUser := &oldUser{
ID: uuid.New().String(),
TimeAuditable: types.TimeAuditable{
CreatedAt: time.Now(),

View File

@ -0,0 +1,141 @@
package sqlmigration
import (
"context"
"github.com/SigNoz/signoz/pkg/factory"
"github.com/SigNoz/signoz/pkg/sqlstore"
"github.com/SigNoz/signoz/pkg/types"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate"
)
type dropGroups struct {
sqlstore sqlstore.SQLStore
}
func NewDropGroupsFactory(sqlstore sqlstore.SQLStore) factory.ProviderFactory[SQLMigration, Config] {
return factory.NewProviderFactory(factory.MustNewName("drop_groups"), func(ctx context.Context, providerSettings factory.ProviderSettings, config Config) (SQLMigration, error) {
return newDropGroups(ctx, providerSettings, config, sqlstore)
})
}
func newDropGroups(_ context.Context, _ factory.ProviderSettings, _ Config, sqlstore sqlstore.SQLStore) (SQLMigration, error) {
return &dropGroups{sqlstore: sqlstore}, nil
}
func (migration *dropGroups) Register(migrations *migrate.Migrations) error {
if err := migrations.Register(migration.Up, migration.Down); err != nil {
return err
}
return nil
}
func (migration *dropGroups) Up(ctx context.Context, db *bun.DB) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
type Group struct {
bun.BaseModel `bun:"table:groups"`
types.TimeAuditable
OrgID string `bun:"org_id,type:text"`
ID string `bun:"id,pk,type:text" json:"id"`
Name string `bun:"name,type:text,notnull,unique" json:"name"`
}
type existingUser struct {
bun.BaseModel `bun:"table:users"`
types.TimeAuditable
ID string `bun:"id,pk,type:text" json:"id"`
Name string `bun:"name,type:text,notnull" json:"name"`
Email string `bun:"email,type:text,notnull,unique" json:"email"`
Password string `bun:"password,type:text,notnull" json:"-"`
ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"`
GroupID string `bun:"group_id,type:text,notnull" json:"groupId"`
OrgID string `bun:"org_id,type:text,notnull" json:"orgId"`
}
var existingUsers []*existingUser
if err := tx.
NewSelect().
Model(&existingUsers).
Scan(ctx); err != nil {
return err
}
var groups []*Group
if err := tx.
NewSelect().
Model(&groups).
Scan(ctx); err != nil {
return err
}
groupIDToRoleMap := make(map[string]string)
for _, group := range groups {
groupIDToRoleMap[group.ID] = group.Name
}
roleToUserIDMap := make(map[string][]string)
for _, user := range existingUsers {
roleToUserIDMap[groupIDToRoleMap[user.GroupID]] = append(roleToUserIDMap[groupIDToRoleMap[user.GroupID]], user.ID)
}
if err := migration.sqlstore.Dialect().DropColumnWithForeignKeyConstraint(ctx, tx, new(struct {
bun.BaseModel `bun:"table:users"`
types.TimeAuditable
ID string `bun:"id,pk,type:text"`
Name string `bun:"name,type:text,notnull"`
Email string `bun:"email,type:text,notnull,unique"`
Password string `bun:"password,type:text,notnull"`
ProfilePictureURL string `bun:"profile_picture_url,type:text"`
OrgID string `bun:"org_id,type:text,notnull"`
}), "group_id"); err != nil {
return err
}
if err := migration.sqlstore.Dialect().AddColumn(ctx, tx, "users", "role", "TEXT"); err != nil {
return err
}
for role, userIDs := range roleToUserIDMap {
if _, err := tx.
NewUpdate().
Table("users").
Set("role = ?", role).
Where("id IN (?)", bun.In(userIDs)).
Exec(ctx); err != nil {
return err
}
}
if err := migration.sqlstore.Dialect().AddNotNullDefaultToColumn(ctx, tx, "users", "role", "TEXT", "'VIEWER'"); err != nil {
return err
}
if _, err := tx.
NewDropTable().
Table("groups").
IfExists().
Exec(ctx); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (migration *dropGroups) Down(ctx context.Context, db *bun.DB) error {
return nil
}

View File

@ -6,7 +6,6 @@ import (
"github.com/SigNoz/signoz/pkg/factory"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/migrate"
)
@ -66,29 +65,3 @@ func MustNew(
}
return migrations
}
func GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) {
var columnType string
var err error
if bun.Dialect().Name() == dialect.SQLite {
err = bun.NewSelect().
ColumnExpr("type").
TableExpr("pragma_table_info(?)", table).
Where("name = ?", column).
Scan(ctx, &columnType)
} else {
err = bun.NewSelect().
ColumnExpr("data_type").
TableExpr("information_schema.columns").
Where("table_name = ?", table).
Where("column_name = ?", column).
Scan(ctx, &columnType)
}
if err != nil {
return "", err
}
return columnType, nil
}

View File

@ -5,33 +5,53 @@ import (
"fmt"
"reflect"
"slices"
"strings"
"github.com/SigNoz/signoz/pkg/errors"
"github.com/uptrace/bun"
)
var (
Identity = "id"
Integer = "INTEGER"
Text = "TEXT"
const (
Identity string = "id"
Integer string = "INTEGER"
Text string = "TEXT"
)
var (
Org = "org"
User = "user"
CloudIntegration = "cloud_integration"
const (
Org string = "org"
User string = "user"
CloudIntegration string = "cloud_integration"
)
var (
OrgReference = `("org_id") REFERENCES "organizations" ("id")`
UserReference = `("user_id") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE`
CloudIntegrationReference = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE`
const (
OrgReference string = `("org_id") REFERENCES "organizations" ("id")`
UserReference string = `("user_id") REFERENCES "users" ("id") ON DELETE CASCADE ON UPDATE CASCADE`
CloudIntegrationReference string = `("cloud_integration_id") REFERENCES "cloud_integration" ("id") ON DELETE CASCADE`
)
type dialect struct {
const (
OrgField string = "org_id"
)
type dialect struct{}
func (dialect *dialect) GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) {
var columnType string
err := bun.
NewSelect().
ColumnExpr("type").
TableExpr("pragma_table_info(?)", table).
Where("name = ?", column).
Scan(ctx, &columnType)
if err != nil {
return "", err
}
return columnType, nil
}
func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
columnType, err := dialect.GetColumnType(ctx, bun, table, column)
if err != nil {
return err
@ -73,7 +93,7 @@ func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB,
return nil
}
func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
columnExists, err := dialect.ColumnExists(ctx, bun, table, column)
if err != nil {
return err
@ -118,22 +138,6 @@ func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, ta
return nil
}
func (dialect *dialect) GetColumnType(ctx context.Context, bun bun.IDB, table string, column string) (string, error) {
var columnType string
err := bun.
NewSelect().
ColumnExpr("type").
TableExpr("pragma_table_info(?)", table).
Where("name = ?", column).
Scan(ctx, &columnType)
if err != nil {
return "", err
}
return columnType, nil
}
func (dialect *dialect) ColumnExists(ctx context.Context, bun bun.IDB, table string, column string) (bool, error) {
var count int
err := bun.NewSelect().
@ -217,7 +221,6 @@ func (dialect *dialect) DropColumn(ctx context.Context, bun bun.IDB, table strin
}
func (dialect *dialect) TableExists(ctx context.Context, bun bun.IDB, table interface{}) (bool, error) {
count := 0
err := bun.
NewSelect().
@ -423,3 +426,63 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel
return nil
}
func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bunIDB bun.IDB, model interface{}, column string) error {
existingTable := bunIDB.Dialect().Tables().Get(reflect.TypeOf(model))
columnExists, err := dialect.ColumnExists(ctx, bunIDB, existingTable.Name, column)
if err != nil {
return err
}
if !columnExists {
return nil
}
newTableName := existingTable.Name + "_tmp"
// Create the newTmpTable query
createTableQuery := bunIDB.NewCreateTable().Model(model).ModelTableExpr(newTableName)
var columnNames []string
for _, field := range existingTable.Fields {
if field.Name != column {
columnNames = append(columnNames, string(field.SQLName))
}
if field.Name == OrgField {
createTableQuery = createTableQuery.ForeignKey(OrgReference)
}
}
// Disable foreign keys temporarily
if _, err := bunIDB.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil {
return err
}
if _, err = createTableQuery.Exec(ctx); err != nil {
return err
}
// Copy data from old table to new table
if _, err := bunIDB.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s SELECT %s FROM %s", newTableName, strings.Join(columnNames, ", "), existingTable.Name)); err != nil {
return err
}
_, err = bunIDB.NewDropTable().Table(existingTable.Name).Exec(ctx)
if err != nil {
return err
}
_, err = bunIDB.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME TO %s", newTableName, existingTable.Name))
if err != nil {
return err
}
// Re-enable foreign keys
if _, err := bunIDB.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
return err
}
return nil
}

View File

@ -37,15 +37,42 @@ type SQLStoreHook interface {
}
type SQLDialect interface {
MigrateIntToTimestamp(context.Context, bun.IDB, string, string) error
MigrateIntToBoolean(context.Context, bun.IDB, string, string) error
AddNotNullDefaultToColumn(context.Context, bun.IDB, string, string, string, string) error
// Returns the type of the column for the given table and column.
GetColumnType(context.Context, bun.IDB, string, string) (string, error)
// Migrates an integer column to a timestamp column for the given table and column.
IntToTimestamp(context.Context, bun.IDB, string, string) error
// Migrates an integer column to a boolean column for the given table and column.
IntToBoolean(context.Context, bun.IDB, string, string) error
// Adds a not null default to the given column for the given table, column, columnType and defaultValue.
AddNotNullDefaultToColumn(context.Context, bun.IDB, string, string, string, string) error
// Checks if a column exists in a table for the given table and column.
ColumnExists(context.Context, bun.IDB, string, string) (bool, error)
// Adds a column to a table for the given table, column and columnType.
AddColumn(context.Context, bun.IDB, string, string, string) error
RenameColumn(context.Context, bun.IDB, string, string, string) (bool, error)
// Drops a column from a table for the given table and column.
DropColumn(context.Context, bun.IDB, string, string) error
// Renames a column in a table for the given table, old column name and new column name.
RenameColumn(context.Context, bun.IDB, string, string, string) (bool, error)
// Renames a table and modifies the given model for the given table, old model, new model, references and callback. The old model
// and new model must inherit bun.BaseModel.
RenameTableAndModifyModel(context.Context, bun.IDB, interface{}, interface{}, []string, func(context.Context) error) error
// Updates the primary key for the given table, old model, new model, reference and callback. The old model and new model
// must inherit bun.BaseModel.
UpdatePrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error
// Adds a primary key to the given table, old model, new model, reference and callback. The old model and new model
// must inherit bun.BaseModel.
AddPrimaryKey(context.Context, bun.IDB, interface{}, interface{}, string, func(context.Context) error) error
// Drops the column and the associated foreign key constraint for the given table and column.
DropColumnWithForeignKeyConstraint(context.Context, bun.IDB, interface{}, string) error
}

View File

@ -9,11 +9,11 @@ import (
type dialect struct {
}
func (dialect *dialect) MigrateIntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
func (dialect *dialect) IntToTimestamp(ctx context.Context, bun bun.IDB, table string, column string) error {
return nil
}
func (dialect *dialect) MigrateIntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
func (dialect *dialect) IntToBoolean(ctx context.Context, bun bun.IDB, table string, column string) error {
return nil
}
@ -56,3 +56,7 @@ func (dialect *dialect) AddPrimaryKey(ctx context.Context, bun bun.IDB, oldModel
func (dialect *dialect) IndexExists(ctx context.Context, bun bun.IDB, table string, index string) (bool, error) {
return false, nil
}
func (dialect *dialect) DropColumnWithForeignKeyConstraint(ctx context.Context, bun bun.IDB, model interface{}, column string) error {
return nil
}

View File

@ -0,0 +1,83 @@
package authtypes
import (
"log/slog"
"slices"
"github.com/SigNoz/signoz/pkg/errors"
"github.com/golang-jwt/jwt/v5"
)
var _ jwt.ClaimsValidator = (*Claims)(nil)
type Claims struct {
jwt.RegisteredClaims
UserID string `json:"id"`
Email string `json:"email"`
Role Role `json:"role"`
OrgID string `json:"orgId"`
}
func (c *Claims) Validate() error {
if c.UserID == "" {
return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "id is required")
}
// The problem is that when the "role" field is missing entirely from the JSON (as opposed to being present but empty), the UnmarshalJSON method for Role isn't called at all.
// The JSON decoder just sets the Role field to its zero value ("").
if c.Role == "" {
return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "role is required")
}
if c.OrgID == "" {
return errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "orgId is required")
}
return nil
}
func (c *Claims) LogValue() slog.Value {
return slog.GroupValue(
slog.String("id", c.UserID),
slog.String("email", c.Email),
slog.String("role", c.Role.String()),
slog.String("orgId", c.OrgID),
slog.Time("exp", c.ExpiresAt.Time),
)
}
func (c *Claims) IsViewer() error {
if slices.Contains([]Role{RoleViewer, RoleEditor, RoleAdmin}, c.Role) {
return nil
}
return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only viewers/editors/admins can access this resource")
}
func (c *Claims) IsEditor() error {
if slices.Contains([]Role{RoleEditor, RoleAdmin}, c.Role) {
return nil
}
return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only editors/admins can access this resource")
}
func (c *Claims) IsAdmin() error {
if c.Role == RoleAdmin {
return nil
}
return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only admins can access this resource")
}
func (c *Claims) IsSelfAccess(id string) error {
if c.UserID == id {
return nil
}
if c.Role == RoleAdmin {
return nil
}
return errors.New(errors.TypeForbidden, errors.CodeForbidden, "only the user/admin can access their own resource")
}

View File

@ -2,24 +2,15 @@ package authtypes
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/SigNoz/signoz/pkg/errors"
"github.com/golang-jwt/jwt/v5"
)
type jwtClaimsKey struct{}
type Claims struct {
jwt.RegisteredClaims
UserID string `json:"id"`
GroupID string `json:"gid"`
Email string `json:"email"`
OrgID string `json:"orgId"`
}
type JWT struct {
JwtSecret string
JwtExpiry time.Duration
@ -34,16 +25,6 @@ func NewJWT(jwtSecret string, jwtExpiry time.Duration, jwtRefresh time.Duration)
}
}
func parseBearerAuth(auth string) (string, bool) {
const prefix = "Bearer "
// Case insensitive prefix match
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return "", false
}
return auth[len(prefix):], true
}
func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context.Context, error) {
var value string
for _, v := range values {
@ -54,7 +35,7 @@ func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context
}
if value == "" {
return ctx, errors.New("missing Authorization header")
return ctx, errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "missing authorization header")
}
// parse from
@ -73,24 +54,18 @@ func (j *JWT) ContextFromRequest(ctx context.Context, values ...string) (context
}
func (j *JWT) Claims(jwtStr string) (Claims, error) {
token, err := jwt.ParseWithClaims(jwtStr, &Claims{}, func(token *jwt.Token) (interface{}, error) {
claims := Claims{}
_, err := jwt.ParseWithClaims(jwtStr, &claims, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unknown signing algo: %v", token.Header["alg"])
return nil, errors.Newf(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unrecognized signing algorithm: %s", token.Method.Alg())
}
return []byte(j.JwtSecret), nil
})
if err != nil {
return Claims{}, fmt.Errorf("failed to parse jwt token: %w", err)
return Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to parse jwt token")
}
// Type assertion to retrieve claims from the token
userClaims, ok := token.Claims.(*Claims)
if !ok {
return Claims{}, errors.New("failed to retrieve claims from token")
}
return *userClaims, nil
return claims, nil
}
// NewContextWithClaims attaches individual claims to the context.
@ -106,36 +81,62 @@ func (j *JWT) signToken(claims Claims) (string, error) {
}
// AccessToken creates an access token with the provided claims
func (j *JWT) AccessToken(orgId, userId, groupId, email string) (string, error) {
func (j *JWT) AccessToken(orgId, userId, email string, role Role) (string, Claims, error) {
claims := Claims{
UserID: userId,
GroupID: groupId,
Email: email,
OrgID: orgId,
UserID: userId,
Role: role,
Email: email,
OrgID: orgId,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtExpiry)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
return j.signToken(claims)
token, err := j.signToken(claims)
if err != nil {
return "", Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to sign token")
}
return token, claims, nil
}
// RefreshToken creates a refresh token with the provided claims
func (j *JWT) RefreshToken(orgId, userId, groupId, email string) (string, error) {
func (j *JWT) RefreshToken(orgId, userId, email string, role Role) (string, Claims, error) {
claims := Claims{
UserID: userId,
GroupID: groupId,
Email: email,
OrgID: orgId,
UserID: userId,
Role: role,
Email: email,
OrgID: orgId,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.JwtRefresh)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
return j.signToken(claims)
token, err := j.signToken(claims)
if err != nil {
return "", Claims{}, errors.Wrapf(err, errors.TypeUnauthenticated, errors.CodeUnauthenticated, "failed to sign token")
}
return token, claims, nil
}
func ClaimsFromContext(ctx context.Context) (Claims, bool) {
func ClaimsFromContext(ctx context.Context) (Claims, error) {
claims, ok := ctx.Value(jwtClaimsKey{}).(Claims)
return claims, ok
if !ok {
return Claims{}, errors.New(errors.TypeUnauthenticated, errors.CodeUnauthenticated, "unauthenticated")
}
return claims, nil
}
func parseBearerAuth(auth string) (string, bool) {
const prefix = "Bearer "
// Case insensitive prefix match
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return "", false
}
return auth[len(prefix):], true
}

View File

@ -4,35 +4,36 @@ import (
"testing"
"time"
"github.com/SigNoz/signoz/pkg/errors"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)
func TestGetAccessJwt(t *testing.T) {
func TestJwtAccessToken(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
token, err := jwtService.AccessToken("orgId", "userId", "groupId", "email@example.com")
token, _, err := jwtService.AccessToken("orgId", "userId", "email@example.com", RoleAdmin)
assert.NoError(t, err)
assert.NotEmpty(t, token)
}
func TestGetRefreshJwt(t *testing.T) {
func TestJwtRefreshToken(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
token, err := jwtService.RefreshToken("orgId", "userId", "groupId", "email@example.com")
token, _, err := jwtService.RefreshToken("orgId", "userId", "email@example.com", RoleAdmin)
assert.NoError(t, err)
assert.NotEmpty(t, token)
}
func TestGetJwtClaims(t *testing.T) {
func TestJwtClaims(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
// Create a valid token
claims := Claims{
UserID: "userId",
GroupID: "groupId",
Email: "email@example.com",
OrgID: "orgId",
UserID: "userId",
Role: RoleAdmin,
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
@ -45,29 +46,28 @@ func TestGetJwtClaims(t *testing.T) {
retrievedClaims, err := jwtService.Claims(tokenString)
assert.NoError(t, err)
assert.Equal(t, claims.UserID, retrievedClaims.UserID)
assert.Equal(t, claims.GroupID, retrievedClaims.GroupID)
assert.Equal(t, claims.Role, retrievedClaims.Role)
assert.Equal(t, claims.Email, retrievedClaims.Email)
assert.Equal(t, claims.OrgID, retrievedClaims.OrgID)
}
func TestGetJwtClaimsInvalidToken(t *testing.T) {
func TestJwtClaimsInvalidToken(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
// Test retrieving claims from an invalid token
_, err := jwtService.Claims("invalid.token.string")
assert.Error(t, err)
assert.Contains(t, err.Error(), "token is malformed")
}
func TestGetJwtClaimsExpiredToken(t *testing.T) {
func TestJwtClaimsExpiredToken(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
// Create an expired token
claims := Claims{
UserID: "userId",
GroupID: "groupId",
Email: "email@example.com",
OrgID: "orgId",
UserID: "userId",
Role: RoleAdmin,
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
@ -81,15 +81,15 @@ func TestGetJwtClaimsExpiredToken(t *testing.T) {
assert.Contains(t, err.Error(), "token is expired")
}
func TestGetJwtClaimsInvalidSignature(t *testing.T) {
func TestJwtClaimsInvalidSignature(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
// Create a valid token
claims := Claims{
UserID: "userId",
GroupID: "groupId",
Email: "email@example.com",
OrgID: "orgId",
UserID: "userId",
Role: RoleAdmin,
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
@ -106,6 +106,86 @@ func TestGetJwtClaimsInvalidSignature(t *testing.T) {
assert.Contains(t, err.Error(), "signature is invalid")
}
func TestJwtClaimsWithInvalidRole(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
claims := Claims{
UserID: "userId",
Role: "INVALID_ROLE",
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
}
validToken, err := jwtService.signToken(claims)
assert.NoError(t, err)
_, err = jwtService.Claims(validToken)
assert.Error(t, err)
assert.True(t, errors.Ast(err, errors.TypeUnauthenticated))
}
func TestJwtClaimsMissingUserID(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
claims := Claims{
UserID: "",
Role: RoleAdmin,
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
}
validToken, err := jwtService.signToken(claims)
assert.NoError(t, err)
_, err = jwtService.Claims(validToken)
assert.Error(t, err)
assert.True(t, errors.Ast(err, errors.TypeUnauthenticated))
}
func TestJwtClaimsMissingRole(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
claims := Claims{
UserID: "userId",
Role: "",
Email: "email@example.com",
OrgID: "orgId",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
}
validToken, err := jwtService.signToken(claims)
assert.NoError(t, err)
_, err = jwtService.Claims(validToken)
assert.Error(t, err)
assert.True(t, errors.Ast(err, errors.TypeUnauthenticated))
}
func TestJwtClaimsMissingOrgID(t *testing.T) {
jwtService := NewJWT("secret", time.Minute, time.Hour)
claims := Claims{
UserID: "userId",
Role: RoleAdmin,
Email: "email@example.com",
OrgID: "",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
},
}
validToken, err := jwtService.signToken(claims)
assert.NoError(t, err)
_, err = jwtService.Claims(validToken)
assert.Error(t, err)
assert.True(t, errors.Ast(err, errors.TypeUnauthenticated))
}
func TestParseBearerAuth(t *testing.T) {
tests := []struct {
auth string

View File

@ -0,0 +1,52 @@
package authtypes
import (
"encoding/json"
"github.com/SigNoz/signoz/pkg/errors"
)
// Do not take inspiration from this. This is a hack to avoid using valuer.String and use upper case strings.
type Role string
const (
RoleAdmin Role = "ADMIN"
RoleEditor Role = "EDITOR"
RoleViewer Role = "VIEWER"
)
func NewRole(role string) (Role, error) {
switch role {
case "ADMIN":
return RoleAdmin, nil
case "EDITOR":
return RoleEditor, nil
case "VIEWER":
return RoleViewer, nil
}
return "", errors.Newf(errors.TypeInvalidInput, errors.CodeInvalidInput, "invalid role: %s", role)
}
func (r Role) String() string {
return string(r)
}
func (r *Role) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
role, err := NewRole(s)
if err != nil {
return err
}
*r = role
return nil
}
func (r Role) MarshalJSON() ([]byte, error) {
return json.Marshal(r.String())
}

View File

@ -16,18 +16,8 @@ type Invite struct {
Role string `bun:"role,type:text,notnull" json:"role"`
}
type Group struct {
bun.BaseModel `bun:"table:groups"`
TimeAuditable
OrgID string `bun:"org_id,type:text"`
ID string `bun:"id,pk,type:text" json:"id"`
Name string `bun:"name,type:text,notnull,unique" json:"name"`
}
type GettableUser struct {
User
Role string `json:"role"`
Organization string `json:"organization"`
}
@ -40,7 +30,7 @@ type User struct {
Email string `bun:"email,type:text,notnull,unique" json:"email"`
Password string `bun:"password,type:text,notnull" json:"-"`
ProfilePictureURL string `bun:"profile_picture_url,type:text" json:"profilePictureURL"`
GroupID string `bun:"group_id,type:text,notnull" json:"groupId"`
Role string `bun:"role,type:text,notnull" json:"role"`
OrgID string `bun:"org_id,type:text,notnull" json:"orgId"`
}

View File

@ -16,6 +16,12 @@ pytest_plugins = [
def pytest_addoption(parser: pytest.Parser):
parser.addoption(
"--dev",
action="store",
default=False,
help="Run in dev mode. In this mode, the containers are not torn down after test run and are reused in subsequent test runs.",
)
parser.addoption(
"--sqlstore-provider",
action="store",

View File

@ -0,0 +1,3 @@
from testcontainers.core.config import testcontainers_config as config
config.ryuk_disabled = True

View File

@ -1,3 +1,4 @@
import dataclasses
import os
from typing import Any, Generator
@ -15,10 +16,46 @@ def clickhouse(
network: Network,
zookeeper: types.TestContainerDocker,
request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> types.TestContainerClickhouse:
"""
Package-scoped fixture for Clickhouse TestContainer.
"""
dev = request.config.getoption("--dev")
if dev:
container = pytestconfig.cache.get("clickhouse.container", None)
env = pytestconfig.cache.get("clickhouse.env", None)
if container and env:
assert isinstance(container, dict)
assert isinstance(env, dict)
test_container = types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
container["host_config"]["scheme"],
container["host_config"]["address"],
container["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
container["container_config"]["scheme"],
container["container_config"]["address"],
container["container_config"]["port"],
),
)
connection = clickhouse_driver.connect(
user=env["SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_USERNAME"],
password=env["SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_PASSWORD"],
host=test_container.host_config.address,
port=test_container.host_config.port,
)
return types.TestContainerClickhouse(
container=test_container,
conn=connection,
env=env,
)
version = request.config.getoption("--clickhouse-version")
container = ClickHouseContainer(
@ -91,21 +128,37 @@ def clickhouse(
)
def stop():
if dev:
return
connection.close()
container.stop(delete_volume=True)
request.addfinalizer(stop)
return types.TestContainerClickhouse(
container=container,
host_config=types.TestContainerUrlConfig(
"tcp", container.get_container_host_ip(), container.get_exposed_port(9000)
),
container_config=types.TestContainerUrlConfig(
"tcp", container.get_wrapped_container().name, 9000
cached_clickhouse = types.TestContainerClickhouse(
container=types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
"tcp",
container.get_container_host_ip(),
container.get_exposed_port(9000),
),
container_config=types.TestContainerUrlConfig(
"tcp", container.get_wrapped_container().name, 9000
),
),
conn=connection,
env={
"SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_DSN": f"tcp://{container.username}:{container.password}@{container.get_wrapped_container().name}:{9000}" # pylint: disable=line-too-long
"SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_DSN": f"tcp://{container.username}:{container.password}@{container.get_wrapped_container().name}:{9000}", # pylint: disable=line-too-long
"SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_USERNAME": container.username,
"SIGNOZ_TELEMETRYSTORE_CLICKHOUSE_PASSWORD": container.password,
},
)
if dev:
pytestconfig.cache.set(
"clickhouse.container", dataclasses.asdict(cached_clickhouse.container)
)
pytestconfig.cache.set("clickhouse.env", cached_clickhouse.env)
return cached_clickhouse

View File

@ -1,3 +1,4 @@
import dataclasses
from typing import List
import pytest
@ -14,23 +15,44 @@ from fixtures import types
@pytest.fixture(name="zeus", scope="package")
def zeus(
network: Network, request: pytest.FixtureRequest
) -> types.TestContainerWiremock:
network: Network,
request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> types.TestContainerDocker:
"""
Package-scoped fixture for running zeus
"""
dev = request.config.getoption("--dev")
if dev:
cached_zeus = pytestconfig.cache.get("zeus", None)
if cached_zeus:
return types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
cached_zeus["host_config"]["scheme"],
cached_zeus["host_config"]["address"],
cached_zeus["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
cached_zeus["container_config"]["scheme"],
cached_zeus["container_config"]["address"],
cached_zeus["container_config"]["port"],
),
)
container = WireMockContainer(image="wiremock/wiremock:2.35.1-1", secure=False)
container.with_network(network)
container.start()
def stop():
if dev:
return
container.stop(delete_volume=True)
request.addfinalizer(stop)
return types.TestContainerWiremock(
container=container,
cached_zeus = types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
"http", container.get_container_host_ip(), container.get_exposed_port(8080)
),
@ -39,11 +61,16 @@ def zeus(
),
)
if dev:
pytestconfig.cache.set("zeus", dataclasses.asdict(cached_zeus))
return cached_zeus
@pytest.fixture(name="make_http_mocks", scope="function")
def make_http_mocks():
def _make_http_mocks(container: WireMockContainer, mappings: List[Mapping]):
Config.base_url = container.get_url("__admin")
def _make_http_mocks(container: types.TestContainerDocker, mappings: List[Mapping]):
Config.base_url = container.host_config.get("/__admin")
for mapping in mappings:
Mappings.create_mapping(mapping=mapping)

View File

@ -0,0 +1,10 @@
import logging
def setup_logger(name: str) -> logging.Logger:
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)
return logger

View File

@ -10,10 +10,17 @@ def migrator(
network: Network,
clickhouse: types.TestContainerClickhouse,
request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> None:
"""
Package-scoped fixture for running schema migrations.
"""
dev = request.config.getoption("--dev")
if dev:
cached_migrator = pytestconfig.cache.get("migrator", None)
if cached_migrator is not None and cached_migrator is True:
return None
version = request.config.getoption("--schema-migrator-version")
client = docker.from_env()
@ -53,3 +60,6 @@ def migrator(
raise RuntimeError("failed to run migrations on clickhouse")
container.remove()
if dev:
pytestconfig.cache.set("migrator", True)

View File

@ -1,18 +1,45 @@
import pytest
from testcontainers.core.container import Network
from fixtures.logger import setup_logger
logger = setup_logger(__name__)
@pytest.fixture(name="network", scope="package")
def network(request: pytest.FixtureRequest) -> Network:
def network(request: pytest.FixtureRequest, pytestconfig: pytest.Config) -> Network:
"""
Package-Scoped fixture for creating a network
"""
nw = Network()
dev = request.config.getoption("--dev")
if dev:
cached_network = pytestconfig.cache.get("network", None)
if cached_network:
logger.info("Using cached Network(%s)", cached_network)
nw.id = cached_network["id"]
nw.name = cached_network["name"]
return nw
nw.create()
def stop():
nw.remove()
dev = request.config.getoption("--dev")
if dev:
logger.info(
"Skipping removal of Network(%s)", {"name": nw.name, "id": nw.id}
)
else:
logger.info("Removing Network(%s)", {"name": nw.name, "id": nw.id})
nw.remove()
request.addfinalizer(stop)
return nw
cached_network = nw
if dev:
pytestconfig.cache.set(
"network", {"name": cached_network.name, "id": cached_network.id}
)
return cached_network

View File

@ -1,3 +1,5 @@
import dataclasses
import psycopg2
import pytest
from testcontainers.core.container import Network
@ -8,11 +10,45 @@ from fixtures import types
@pytest.fixture(name="postgres", scope="package")
def postgres(
network: Network, request: pytest.FixtureRequest
network: Network, request: pytest.FixtureRequest, pytestconfig: pytest.Config
) -> types.TestContainerSQL:
"""
Package-scoped fixture for PostgreSQL TestContainer.
"""
dev = request.config.getoption("--dev")
if dev:
container = pytestconfig.cache.get("postgres.container", None)
env = pytestconfig.cache.get("postgres.env", None)
if container and env:
assert isinstance(container, dict)
assert isinstance(env, dict)
test_container = types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
container["host_config"]["scheme"],
container["host_config"]["address"],
container["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
container["container_config"]["scheme"],
container["container_config"]["address"],
container["container_config"]["port"],
),
)
return types.TestContainerSQL(
container=test_container,
conn=psycopg2.connect(
dbname=env["SIGNOZ_SQLSTORE_POSTGRES_DBNAME"],
user=env["SIGNOZ_SQLSTORE_POSTGRES_USER"],
password=env["SIGNOZ_SQLSTORE_POSTGRES_PASSWORD"],
host=test_container.host_config.address,
port=test_container.host_config.port,
),
env=env,
)
version = request.config.getoption("--postgres-version")
container = PostgresContainer(
@ -35,24 +71,39 @@ def postgres(
)
def stop():
if dev:
return
connection.close()
container.stop(delete_volume=True)
request.addfinalizer(stop)
return types.TestContainerSQL(
container=container,
host_config=types.TestContainerUrlConfig(
"postgresql",
container.get_container_host_ip(),
container.get_exposed_port(5432),
),
container_config=types.TestContainerUrlConfig(
"postgresql", container.get_wrapped_container().name, 5432
cached_postgres = types.TestContainerSQL(
container=types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
"postgresql",
container.get_container_host_ip(),
container.get_exposed_port(5432),
),
container_config=types.TestContainerUrlConfig(
"postgresql", container.get_wrapped_container().name, 5432
),
),
conn=connection,
env={
"SIGNOZ_SQLSTORE_PROVIDER": "postgres",
"SIGNOZ_SQLSTORE_POSTGRES_DSN": f"postgresql://{container.username}:{container.password}@{container.get_wrapped_container().name}:{5432}/{container.dbname}", # pylint: disable=line-too-long
"SIGNOZ_SQLSTORE_POSTGRES_DBNAME": container.dbname,
"SIGNOZ_SQLSTORE_POSTGRES_USER": container.username,
"SIGNOZ_SQLSTORE_POSTGRES_PASSWORD": container.password,
},
)
if dev:
pytestconfig.cache.set(
"postgres.container", dataclasses.asdict(cached_postgres.container)
)
pytestconfig.cache.set("postgres.env", cached_postgres.env)
return cached_postgres

View File

@ -1,3 +1,4 @@
import dataclasses
import platform
import time
from http import HTTPStatus
@ -8,20 +9,44 @@ from testcontainers.core.container import DockerContainer, Network
from testcontainers.core.image import DockerImage
from fixtures import types
from fixtures.logger import setup_logger
logger = setup_logger(__name__)
@pytest.fixture(name="signoz", scope="package")
def signoz(
network: Network,
zeus: types.TestContainerWiremock,
zeus: types.TestContainerDocker,
sqlstore: types.TestContainerSQL,
clickhouse: types.TestContainerClickhouse,
request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> types.SigNoz:
"""
Package-scoped fixture for setting up SigNoz.
"""
dev = request.config.getoption("--dev")
if dev:
cached_signoz = pytestconfig.cache.get("signoz.container", None)
if cached_signoz:
self = types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
cached_signoz["host_config"]["scheme"],
cached_signoz["host_config"]["address"],
cached_signoz["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
cached_signoz["container_config"]["scheme"],
cached_signoz["container_config"]["address"],
cached_signoz["container_config"]["port"],
),
)
return types.SigNoz(
self=self, sqlstore=sqlstore, telemetrystore=clickhouse, zeus=zeus
)
# Run the migrations for clickhouse
request.getfixturevalue("migrator")
@ -70,7 +95,7 @@ def signoz(
container.start()
def ready(container: DockerContainer) -> None:
for attempt in range(30):
for attempt in range(5):
try:
response = requests.get(
f"http://{container.get_container_host_ip()}:{container.get_exposed_port(8080)}/api/v1/health", # pylint: disable=line-too-long
@ -78,22 +103,31 @@ def signoz(
)
return response.status_code == HTTPStatus.OK
except Exception: # pylint: disable=broad-exception-caught
print(f"attempt {attempt} at health check failed")
logger.info(
"Attempt %s at readiness check for SigNoz container %s failed, going to retry ...", # pylint: disable=line-too-long
attempt + 1,
container,
)
time.sleep(2)
raise TimeoutError("timeout exceeded while waiting")
ready(container=container)
try:
ready(container=container)
except Exception as e: # pylint: disable=broad-exception-caught
raise e
def stop():
logs = container.get_wrapped_container().logs(tail=100)
print(logs.decode(encoding="utf-8"))
container.stop(delete_volume=True)
if dev:
logger.info("Skipping removal of SigNoz container %s ...", container)
return
else:
logger.info("Removing SigNoz container %s ...", container)
container.stop(delete_volume=True)
request.addfinalizer(stop)
return types.SigNoz(
cached_signoz = types.SigNoz(
self=types.TestContainerDocker(
container=container,
host_config=types.TestContainerUrlConfig(
"http",
container.get_container_host_ip(),
@ -109,3 +143,10 @@ def signoz(
telemetrystore=clickhouse,
zeus=zeus,
)
if dev:
pytestconfig.cache.set(
"signoz.container", dataclasses.asdict(cached_signoz.self)
)
return cached_signoz

View File

@ -11,27 +11,59 @@ ConnectionTuple = namedtuple("ConnectionTuple", "connection config")
@pytest.fixture(name="sqlite", scope="package")
def sqlite(
tmpfs: Generator[types.LegacyPath, Any, None], request: pytest.FixtureRequest
tmpfs: Generator[types.LegacyPath, Any, None],
request: pytest.FixtureRequest,
pytestconfig: pytest.Config,
) -> types.TestContainerSQL:
"""
Package-scoped fixture for SQLite.
"""
dev = request.config.getoption("--dev")
if dev:
container = pytestconfig.cache.get("sqlite.container", None)
env = pytestconfig.cache.get("sqlite.env", None)
if container and env:
assert isinstance(container, dict)
assert isinstance(env, dict)
return types.TestContainerSQL(
container=types.TestContainerDocker(
host_config=None,
container_config=None,
),
conn=sqlite3.connect(
env["SIGNOZ_SQLSTORE_SQLITE_PATH"], check_same_thread=False
),
env=env,
)
tmpdir = tmpfs("sqlite")
path = tmpdir / "signoz.db"
connection = sqlite3.connect(path, check_same_thread=False)
def stop():
if dev:
return
connection.close()
request.addfinalizer(stop)
return types.TestContainerSQL(
None,
host_config=None,
container_config=None,
cached_sqlite = types.TestContainerSQL(
container=types.TestContainerDocker(
host_config=None,
container_config=None,
),
conn=connection,
env={
"SIGNOZ_SQLSTORE_PROVIDER": "sqlite",
"SIGNOZ_SQLSTORE_SQLITE_PATH": str(path),
},
)
if dev:
pytestconfig.cache.set("sqlite.env", cached_sqlite.env)
return cached_sqlite

View File

@ -4,8 +4,6 @@ from urllib.parse import urljoin
import py
from clickhouse_driver.dbapi import Connection
from testcontainers.core.container import DockerContainer
from wiremock.testing.testcontainer import WireMockContainer
LegacyPath = py.path.local
@ -27,29 +25,22 @@ class TestContainerUrlConfig:
@dataclass
class TestContainerDocker:
__test__ = False
container: DockerContainer
host_config: TestContainerUrlConfig
container_config: TestContainerUrlConfig
@dataclass
class TestContainerWiremock(TestContainerDocker):
class TestContainerSQL:
__test__ = False
container: WireMockContainer
@dataclass
class TestContainerSQL(TestContainerDocker):
__test__ = False
container: DockerContainer
container: TestContainerDocker
conn: any
env: Dict[str, str]
@dataclass
class TestContainerClickhouse(TestContainerDocker):
class TestContainerClickhouse:
__test__ = False
container: DockerContainer
container: TestContainerDocker
conn: Connection
env: Dict[str, str]
@ -60,4 +51,4 @@ class SigNoz:
self: TestContainerDocker
sqlstore: TestContainerSQL
telemetrystore: TestContainerClickhouse
zeus: TestContainerWiremock
zeus: TestContainerDocker

View File

@ -1,3 +1,5 @@
import dataclasses
import pytest
from testcontainers.core.container import DockerContainer, Network
@ -6,11 +8,29 @@ from fixtures import types
@pytest.fixture(name="zookeeper", scope="package")
def zookeeper(
network: Network, request: pytest.FixtureRequest
network: Network, request: pytest.FixtureRequest, pytestconfig: pytest.Config
) -> types.TestContainerDocker:
"""
Package-scoped fixture for Zookeeper TestContainer.
"""
dev = request.config.getoption("--dev")
if dev:
cached_zookeeper = pytestconfig.cache.get("zookeeper", None)
if cached_zookeeper:
return types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
cached_zookeeper["host_config"]["scheme"],
cached_zookeeper["host_config"]["address"],
cached_zookeeper["host_config"]["port"],
),
container_config=types.TestContainerUrlConfig(
cached_zookeeper["container_config"]["scheme"],
cached_zookeeper["container_config"]["address"],
cached_zookeeper["container_config"]["port"],
),
)
version = request.config.getoption("--zookeeper-version")
container = DockerContainer(image=f"bitnami/zookeeper:{version}")
@ -21,12 +41,14 @@ def zookeeper(
container.start()
def stop():
if dev:
return
container.stop(delete_volume=True)
request.addfinalizer(stop)
return types.TestContainerDocker(
container=container,
cached_zookeeper = types.TestContainerDocker(
host_config=types.TestContainerUrlConfig(
"tcp",
container.get_container_host_ip(),
@ -38,3 +60,8 @@ def zookeeper(
2181,
),
)
if dev:
pytestconfig.cache.set("zookeeper", dataclasses.asdict(cached_zookeeper))
return cached_zookeeper

View File

@ -27,6 +27,10 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
python_files = "src/**/**.py"
log_cli = true
log_format = "%(asctime)s [%(levelname)s] (%(filename)s:%(lineno)s) %(message)s"
log_date_format = "%Y-%m-%d %H:%M:%S"
addopts = "-ra"
[tool.pylint.main]
ignore = [".venv"]

View File

@ -3,9 +3,12 @@ from http import HTTPStatus
import requests
from fixtures import types
from fixtures.logger import setup_logger
logger = setup_logger(__name__)
def test_register(signoz: types.SigNoz) -> None:
def test_register(signoz: types.SigNoz, get_jwt_token) -> None:
response = requests.get(signoz.self.host_config.get("/api/v1/version"), timeout=2)
assert response.status_code == HTTPStatus.OK
@ -16,8 +19,8 @@ def test_register(signoz: types.SigNoz) -> None:
json={
"name": "admin",
"orgId": "",
"orgName": "",
"email": "admin@admin.com",
"orgName": "integration.test",
"email": "admin@integration.test",
"password": "password",
},
timeout=2,
@ -28,3 +31,175 @@ def test_register(signoz: types.SigNoz) -> None:
assert response.status_code == HTTPStatus.OK
assert response.json()["setupCompleted"] is True
admin_token = get_jwt_token("admin@integration.test", "password")
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
user_response = response.json()
found_user = next(
(user for user in user_response if user["email"] == "admin@integration.test"),
None,
)
assert found_user is not None
assert found_user["role"] == "ADMIN"
response = requests.get(
signoz.self.host_config.get(f"/api/v1/rbac/role/{found_user["id"]}"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
assert response.json()["group_name"] == "ADMIN"
def test_invite_and_register(signoz: types.SigNoz, get_jwt_token) -> None:
# Generate an invite token for the editor user
response = requests.post(
signoz.self.host_config.get("/api/v1/invite"),
json={"email": "editor@integration.test", "role": "EDITOR"},
timeout=2,
headers={
"Authorization": f"Bearer {get_jwt_token("admin@integration.test", "password")}" # pylint: disable=line-too-long
},
)
assert response.status_code == HTTPStatus.OK
invite_response = response.json()
assert "email" in invite_response
assert "inviteToken" in invite_response
assert invite_response["email"] == "editor@integration.test"
# Register the editor user using the invite token
response = requests.post(
signoz.self.host_config.get("/api/v1/register"),
json={
"email": "editor@integration.test",
"password": "password",
"name": "editor",
"token": f"{invite_response["inviteToken"]}",
},
timeout=2,
)
assert response.status_code == HTTPStatus.OK
# Verify that the invite token has been deleted
response = requests.get(
signoz.self.host_config.get(
f"/api/v1/invite/{invite_response["inviteToken"]}"
), # pylint: disable=line-too-long
timeout=2,
)
assert response.status_code in (HTTPStatus.NOT_FOUND, HTTPStatus.BAD_REQUEST)
# Verify that an admin endpoint cannot be called by the editor user
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={
"Authorization": f"Bearer {get_jwt_token("editor@integration.test", "password")}" # pylint: disable=line-too-long
},
)
assert response.status_code == HTTPStatus.FORBIDDEN
# Verify that the editor has been created
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={
"Authorization": f"Bearer {get_jwt_token("admin@integration.test", "password")}" # pylint: disable=line-too-long
},
)
assert response.status_code == HTTPStatus.OK
user_response = response.json()
found_user = next(
(user for user in user_response if user["email"] == "editor@integration.test"),
None,
)
assert found_user is not None
assert found_user["role"] == "EDITOR"
assert found_user["name"] == "editor"
assert found_user["email"] == "editor@integration.test"
def test_revoke_invite_and_register(signoz: types.SigNoz, get_jwt_token) -> None:
admin_token = get_jwt_token("admin@integration.test", "password")
# Generate an invite token for the viewer user
response = requests.post(
signoz.self.host_config.get("/api/v1/invite"),
json={"email": "viewer@integration.test", "role": "VIEWER"},
timeout=2,
headers={
"Authorization": f"Bearer {admin_token}" # pylint: disable=line-too-long
},
)
assert response.status_code == HTTPStatus.OK
invite_response = response.json()
assert "email" in invite_response
assert "inviteToken" in invite_response
response = requests.delete(
signoz.self.host_config.get(f"/api/v1/invite/{invite_response['email']}"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
# Try registering the viewer user with the invite token
response = requests.post(
signoz.self.host_config.get("/api/v1/register"),
json={
"email": "viewer@integration.test",
"password": "password",
"name": "viewer",
"token": f"{invite_response["inviteToken"]}",
},
timeout=2,
)
assert response.status_code in (HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND)
def test_self_access(signoz: types.SigNoz, get_jwt_token) -> None:
admin_token = get_jwt_token("admin@integration.test", "password")
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
user_response = response.json()
found_user = next(
(user for user in user_response if user["email"] == "editor@integration.test"),
None,
)
response = requests.get(
signoz.self.host_config.get(f"/api/v1/rbac/role/{found_user['id']}"),
timeout=2,
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == HTTPStatus.OK
assert response.json()["group_name"] == "EDITOR"

View File

@ -14,7 +14,7 @@ from fixtures.types import SigNoz
def test_apply_license(signoz: SigNoz, make_http_mocks, get_jwt_token) -> None:
make_http_mocks(
signoz.zeus.container,
signoz.zeus,
[
Mapping(
request=MappingRequest(
@ -51,7 +51,7 @@ def test_apply_license(signoz: SigNoz, make_http_mocks, get_jwt_token) -> None:
],
)
access_token = get_jwt_token("admin@admin.com", "password")
access_token = get_jwt_token("admin@integration.test", "password")
response = requests.post(
url=signoz.self.host_config.get("/api/v3/licenses"),

View File

@ -0,0 +1,54 @@
from fixtures import types
import requests
from http import HTTPStatus
def test_api_key(signoz: types.SigNoz, get_jwt_token) -> None:
admin_token = get_jwt_token("admin@integration.test", "password")
response = requests.post(
signoz.self.host_config.get("/api/v1/pats"),
headers={"Authorization": f"Bearer {admin_token}"},
json={
"name": "admin",
"role": "ADMIN",
"expiresInDays": 1,
},
)
assert response.status_code == HTTPStatus.OK
pat_response = response.json()
assert "data" in pat_response
assert "token" in pat_response["data"]
response = requests.get(
signoz.self.host_config.get("/api/v1/user"),
timeout=2,
headers={"SIGNOZ-API-KEY": f"{pat_response["data"]["token"]}"},
)
assert response.status_code == HTTPStatus.OK
user_response = response.json()
found_user = next(
(user for user in user_response if user["email"] == "admin@integration.test"),
None,
)
response = requests.get(
signoz.self.host_config.get("/api/v1/pats"),
headers={"SIGNOZ-API-KEY": f"{pat_response["data"]["token"]}"},
)
assert response.status_code == HTTPStatus.OK
assert "data" in response.json()
found_pat = next(
(pat for pat in response.json()["data"] if pat["userId"] == found_user["id"]),
None,
)
assert found_pat is not None
assert found_pat["userId"] == found_user["id"]
assert found_pat["name"] == "admin"
assert found_pat["role"] == "ADMIN"