From 388ef9453ca1ddb1c6bed03b88b4acd42cab0a95 Mon Sep 17 00:00:00 2001 From: Ahsan Barkati Date: Wed, 15 Feb 2023 01:34:22 +0530 Subject: [PATCH 1/6] Add APIs for PAT --- ee/query-service/app/api/api.go | 5 ++ ee/query-service/app/api/pat.go | 78 +++++++++++++++++++++++++ ee/query-service/dao/interface.go | 7 ++- ee/query-service/dao/sqlite/modelDao.go | 12 +++- ee/query-service/dao/sqlite/pat.go | 44 ++++++++++++++ ee/query-service/model/pat.go | 10 ++++ 6 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 ee/query-service/app/api/pat.go create mode 100644 ee/query-service/dao/sqlite/pat.go create mode 100644 ee/query-service/model/pat.go diff --git a/ee/query-service/app/api/api.go b/ee/query-service/app/api/api.go index 601bed7714..93330c0572 100644 --- a/ee/query-service/app/api/api.go +++ b/ee/query-service/app/api/api.go @@ -122,6 +122,11 @@ func (ah *APIHandler) RegisterRoutes(router *mux.Router) { router.HandleFunc("/api/v1/traces/{traceId}", baseapp.ViewAccess(ah.searchTraces)).Methods(http.MethodGet) router.HandleFunc("/api/v2/metrics/query_range", baseapp.ViewAccess(ah.queryRangeMetricsV2)).Methods(http.MethodPost) + // PAT APIs + router.HandleFunc("/api/v1/pat", baseapp.SelfAccess(ah.createPAT)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/pat", baseapp.SelfAccess(ah.getPATs)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/pat/{id}", baseapp.SelfAccess(ah.deletePAT)).Methods(http.MethodDelete) + ah.APIHandler.RegisterRoutes(router) } diff --git a/ee/query-service/app/api/pat.go b/ee/query-service/app/api/pat.go new file mode 100644 index 0000000000..1b32ddfde5 --- /dev/null +++ b/ee/query-service/app/api/pat.go @@ -0,0 +1,78 @@ +package api + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/mux" + "go.signoz.io/signoz/ee/query-service/model" + "go.signoz.io/signoz/pkg/query-service/auth" + "go.uber.org/zap" +) + +func generatePATToken() string { + // Generate a 32-byte random token. + token := make([]byte, 32) + rand.Read(token) + // Encode the token in base64. + encodedToken := base64.StdEncoding.EncodeToString(token) + return encodedToken +} + +func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { + ctx := context.Background() + + req := model.PAT{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + RespondError(w, model.BadRequest(err), nil) + return + } + user, err := auth.GetUserFromRequest(r) + if err != nil { + RespondError(w, &model.ApiError{ + Typ: model.ErrorUnauthorized, + Err: err, + }, nil) + return + } + + req.UserID = user.Id + req.CreatedAt = time.Now().Unix() + req.Token = generatePATToken() + + zap.S().Infof("Got PAT request: %+v", req) + + if apierr := ah.AppDao().CreatePAT(ctx, &req); apierr != nil { + RespondError(w, apierr, nil) + return + } + + ah.Respond(w, &req) +} + +func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { + ctx := context.Background() + user, _ := auth.GetUserFromRequest(r) + zap.S().Infof("Get PATs for user: %+v", user.Id) + pats, apierr := ah.AppDao().ListPATs(ctx, user.Id) + if apierr != nil { + RespondError(w, apierr, nil) + return + } + ah.WriteJSON(w, r, pats) +} + +func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) { + ctx := context.Background() + id := mux.Vars(r)["id"] + zap.S().Infof("Delete PAT with id: %+v", id) + if apierr := ah.AppDao().DeletePAT(ctx, id); apierr != nil { + RespondError(w, apierr, nil) + return + } + ah.WriteJSON(w, r, map[string]string{"data": "pat deleted successfully"}) +} diff --git a/ee/query-service/dao/interface.go b/ee/query-service/dao/interface.go index a2c9d9d68d..a74fa5c6f2 100644 --- a/ee/query-service/dao/interface.go +++ b/ee/query-service/dao/interface.go @@ -3,6 +3,7 @@ package dao import ( "context" "net/url" + "github.com/google/uuid" "github.com/jmoiron/sqlx" "go.signoz.io/signoz/ee/query-service/model" @@ -24,7 +25,7 @@ type ModelDao interface { CanUsePassword(ctx context.Context, email string) (bool, basemodel.BaseApiError) PrepareSsoRedirect(ctx context.Context, redirectUri, email string) (redirectURL string, apierr basemodel.BaseApiError) GetDomainFromSsoResponse(ctx context.Context, relayState *url.URL) (*model.OrgDomain, error) - + // org domain (auth domains) CRUD ops ListDomains(ctx context.Context, orgId string) ([]model.OrgDomain, basemodel.BaseApiError) GetDomain(ctx context.Context, id uuid.UUID) (*model.OrgDomain, basemodel.BaseApiError) @@ -32,4 +33,8 @@ type ModelDao interface { UpdateDomain(ctx context.Context, domain *model.OrgDomain) basemodel.BaseApiError DeleteDomain(ctx context.Context, id uuid.UUID) basemodel.BaseApiError GetDomainByEmail(ctx context.Context, email string) (*model.OrgDomain, basemodel.BaseApiError) + + CreatePAT(ctx context.Context, p *model.PAT) basemodel.BaseApiError + ListPATs(ctx context.Context, userID string) ([]model.PAT, basemodel.BaseApiError) + DeletePAT(ctx context.Context, id string) basemodel.BaseApiError } diff --git a/ee/query-service/dao/sqlite/modelDao.go b/ee/query-service/dao/sqlite/modelDao.go index 156f6b30e7..9b1d74c034 100644 --- a/ee/query-service/dao/sqlite/modelDao.go +++ b/ee/query-service/dao/sqlite/modelDao.go @@ -48,7 +48,17 @@ func InitDB(dataSourceName string) (*modelDao, error) { updated_at INTEGER, data TEXT NOT NULL, FOREIGN KEY(org_id) REFERENCES organizations(id) - );` + ); + CREATE TABLE IF NOT EXISTS personal_access_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + token TEXT NOT NULL, + name TEXT NOT NULL, + created_at INTEGER NOT NULL, + expires_at INTEGER NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(id) + ); + ` _, err = m.DB().Exec(table_schema) if err != nil { diff --git a/ee/query-service/dao/sqlite/pat.go b/ee/query-service/dao/sqlite/pat.go new file mode 100644 index 0000000000..340af1f854 --- /dev/null +++ b/ee/query-service/dao/sqlite/pat.go @@ -0,0 +1,44 @@ +package sqlite + +import ( + "context" + "fmt" + + "go.signoz.io/signoz/ee/query-service/model" + basemodel "go.signoz.io/signoz/pkg/query-service/model" + "go.uber.org/zap" +) + +func (m *modelDao) CreatePAT(ctx context.Context, p *model.PAT) basemodel.BaseApiError { + _, err := m.DB().ExecContext(ctx, + "INSERT INTO personal_access_tokens (user_id, token, name, created_at, expires_at) VALUES ($1, $2, $3, $4, $5)", + p.UserID, + p.Token, + p.Name, + p.CreatedAt, + p.ExpiresAt) + if err != nil { + zap.S().Errorf("Failed to insert PAT in db, err: %v", zap.Error(err)) + return model.InternalError(fmt.Errorf("PAT insertion failed")) + } + return nil +} + +func (m *modelDao) ListPATs(ctx context.Context, userID string) ([]model.PAT, basemodel.BaseApiError) { + pats := []model.PAT{} + + if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE user_id=?;`, userID); err != nil { + zap.S().Errorf("Failed to fetch PATs for user: %s, err: %v", userID, zap.Error(err)) + return nil, model.InternalError(fmt.Errorf("Failed to fetch PATs")) + } + return pats, nil +} + +func (m *modelDao) DeletePAT(ctx context.Context, id string) basemodel.BaseApiError { + _, err := m.DB().ExecContext(ctx, `DELETE from personal_access_tokens where id=?;`, id) + if err != nil { + zap.S().Errorf("Failed to delete PAT, err: %v", zap.Error(err)) + return model.InternalError(fmt.Errorf("Failed to delete PAT")) + } + return nil +} diff --git a/ee/query-service/model/pat.go b/ee/query-service/model/pat.go new file mode 100644 index 0000000000..f320d0be7c --- /dev/null +++ b/ee/query-service/model/pat.go @@ -0,0 +1,10 @@ +package model + +type PAT struct { + Id string `json:"id" db:"id"` + UserID string `json:"userId" db:"user_id"` + Token string `json:"token" db:"token"` + Name string `json:"name" db:"name"` + CreatedAt int64 `json:"createdAt" db:"created_at"` + ExpiresAt int64 `json:"expiresAt" db:"expires_at"` +} From 96267e2e3a28d1f59fb69220aa10defdf7ecd717 Mon Sep 17 00:00:00 2001 From: Ahsan Barkati Date: Wed, 15 Feb 2023 11:13:33 +0530 Subject: [PATCH 2/6] Add GetPAT function --- ee/query-service/app/api/pat.go | 1 - ee/query-service/dao/interface.go | 1 + ee/query-service/dao/sqlite/pat.go | 22 ++++++++++++++++++++-- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/ee/query-service/app/api/pat.go b/ee/query-service/app/api/pat.go index 1b32ddfde5..7c6052cffb 100644 --- a/ee/query-service/app/api/pat.go +++ b/ee/query-service/app/api/pat.go @@ -45,7 +45,6 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { req.Token = generatePATToken() zap.S().Infof("Got PAT request: %+v", req) - if apierr := ah.AppDao().CreatePAT(ctx, &req); apierr != nil { RespondError(w, apierr, nil) return diff --git a/ee/query-service/dao/interface.go b/ee/query-service/dao/interface.go index a74fa5c6f2..5bb7c02412 100644 --- a/ee/query-service/dao/interface.go +++ b/ee/query-service/dao/interface.go @@ -35,6 +35,7 @@ type ModelDao interface { GetDomainByEmail(ctx context.Context, email string) (*model.OrgDomain, basemodel.BaseApiError) CreatePAT(ctx context.Context, p *model.PAT) basemodel.BaseApiError + GetPAT(ctx context.Context, patID string) (*model.PAT, basemodel.BaseApiError) ListPATs(ctx context.Context, userID string) ([]model.PAT, basemodel.BaseApiError) DeletePAT(ctx context.Context, id string) basemodel.BaseApiError } diff --git a/ee/query-service/dao/sqlite/pat.go b/ee/query-service/dao/sqlite/pat.go index 340af1f854..fc3c406ecb 100644 --- a/ee/query-service/dao/sqlite/pat.go +++ b/ee/query-service/dao/sqlite/pat.go @@ -29,7 +29,7 @@ func (m *modelDao) ListPATs(ctx context.Context, userID string) ([]model.PAT, ba if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE user_id=?;`, userID); err != nil { zap.S().Errorf("Failed to fetch PATs for user: %s, err: %v", userID, zap.Error(err)) - return nil, model.InternalError(fmt.Errorf("Failed to fetch PATs")) + return nil, model.InternalError(fmt.Errorf("failed to fetch PATs")) } return pats, nil } @@ -38,7 +38,25 @@ func (m *modelDao) DeletePAT(ctx context.Context, id string) basemodel.BaseApiEr _, err := m.DB().ExecContext(ctx, `DELETE from personal_access_tokens where id=?;`, id) if err != nil { zap.S().Errorf("Failed to delete PAT, err: %v", zap.Error(err)) - return model.InternalError(fmt.Errorf("Failed to delete PAT")) + return model.InternalError(fmt.Errorf("failed to delete PAT")) } return nil } + +func (m *modelDao) GetPAT(ctx context.Context, patID string) (*model.PAT, basemodel.BaseApiError) { + pats := []model.PAT{} + + if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE id=?;`, patID); err != nil { + zap.S().Errorf("Failed to fetch PAT with ID: %s, err: %v", patID, zap.Error(err)) + return nil, model.InternalError(fmt.Errorf("failed to fetch PAT")) + } + + if len(pats) != 1 { + return nil, &model.ApiError{ + Typ: model.ErrorInternal, + Err: fmt.Errorf("found zero or multiple PATS with same ID"), + } + } + + return &pats[0], nil +} From 797352583a4266eeec7a0b479c7fff4a5648dcb1 Mon Sep 17 00:00:00 2001 From: Ahsan Barkati Date: Wed, 15 Feb 2023 23:49:03 +0530 Subject: [PATCH 3/6] Create PAT supporting auth middleware --- ee/query-service/app/api/api.go | 42 ++--- ee/query-service/app/api/pat.go | 30 ++- ee/query-service/app/server.go | 53 +++++- ee/query-service/dao/interface.go | 2 +- ee/query-service/dao/sqlite/pat.go | 7 +- pkg/query-service/app/auth.go | 116 ++++++++++++ pkg/query-service/app/http_handler.go | 253 ++++++++------------------ pkg/query-service/app/server.go | 9 +- pkg/query-service/auth/rbac.go | 14 ++ 9 files changed, 320 insertions(+), 206 deletions(-) create mode 100644 pkg/query-service/app/auth.go diff --git a/ee/query-service/app/api/api.go b/ee/query-service/app/api/api.go index 93330c0572..42410b65e7 100644 --- a/ee/query-service/app/api/api.go +++ b/ee/query-service/app/api/api.go @@ -69,65 +69,65 @@ func (ah *APIHandler) CheckFeature(f string) bool { } // RegisterRoutes registers routes for this handler on the given router -func (ah *APIHandler) RegisterRoutes(router *mux.Router) { +func (ah *APIHandler) RegisterRoutes(router *mux.Router, am *baseapp.AuthMiddleware) { // note: add ee override methods first // routes available only in ee version router.HandleFunc("/api/v1/licenses", - baseapp.AdminAccess(ah.listLicenses)). + am.AdminAccess(ah.listLicenses)). Methods(http.MethodGet) router.HandleFunc("/api/v1/licenses", - baseapp.AdminAccess(ah.applyLicense)). + am.AdminAccess(ah.applyLicense)). Methods(http.MethodPost) router.HandleFunc("/api/v1/featureFlags", - baseapp.OpenAccess(ah.getFeatureFlags)). + am.OpenAccess(ah.getFeatureFlags)). Methods(http.MethodGet) router.HandleFunc("/api/v1/loginPrecheck", - baseapp.OpenAccess(ah.precheckLogin)). + am.OpenAccess(ah.precheckLogin)). Methods(http.MethodGet) // paid plans specific routes router.HandleFunc("/api/v1/complete/saml", - baseapp.OpenAccess(ah.receiveSAML)). + am.OpenAccess(ah.receiveSAML)). Methods(http.MethodPost) router.HandleFunc("/api/v1/complete/google", - baseapp.OpenAccess(ah.receiveGoogleAuth)). + am.OpenAccess(ah.receiveGoogleAuth)). Methods(http.MethodGet) router.HandleFunc("/api/v1/orgs/{orgId}/domains", - baseapp.AdminAccess(ah.listDomainsByOrg)). + am.AdminAccess(ah.listDomainsByOrg)). Methods(http.MethodGet) router.HandleFunc("/api/v1/domains", - baseapp.AdminAccess(ah.postDomain)). + am.AdminAccess(ah.postDomain)). Methods(http.MethodPost) router.HandleFunc("/api/v1/domains/{id}", - baseapp.AdminAccess(ah.putDomain)). + am.AdminAccess(ah.putDomain)). Methods(http.MethodPut) router.HandleFunc("/api/v1/domains/{id}", - baseapp.AdminAccess(ah.deleteDomain)). + am.AdminAccess(ah.deleteDomain)). Methods(http.MethodDelete) // base overrides - router.HandleFunc("/api/v1/version", baseapp.OpenAccess(ah.getVersion)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/invite/{token}", baseapp.OpenAccess(ah.getInvite)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/register", baseapp.OpenAccess(ah.registerUser)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/login", baseapp.OpenAccess(ah.loginUser)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/traces/{traceId}", baseapp.ViewAccess(ah.searchTraces)).Methods(http.MethodGet) - router.HandleFunc("/api/v2/metrics/query_range", baseapp.ViewAccess(ah.queryRangeMetricsV2)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/version", am.OpenAccess(ah.getVersion)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/invite/{token}", am.OpenAccess(ah.getInvite)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/register", am.OpenAccess(ah.registerUser)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/login", am.OpenAccess(ah.loginUser)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/traces/{traceId}", am.ViewAccess(ah.searchTraces)).Methods(http.MethodGet) + router.HandleFunc("/api/v2/metrics/query_range", am.ViewAccess(ah.queryRangeMetricsV2)).Methods(http.MethodPost) // PAT APIs - router.HandleFunc("/api/v1/pat", baseapp.SelfAccess(ah.createPAT)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/pat", baseapp.SelfAccess(ah.getPATs)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/pat/{id}", baseapp.SelfAccess(ah.deletePAT)).Methods(http.MethodDelete) + router.HandleFunc("/api/v1/pat", am.OpenAccess(ah.createPAT)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/pat", am.OpenAccess(ah.getPATs)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/pat/{id}", am.OpenAccess(ah.deletePAT)).Methods(http.MethodDelete) - ah.APIHandler.RegisterRoutes(router) + ah.APIHandler.RegisterRoutes(router, am) } diff --git a/ee/query-service/app/api/pat.go b/ee/query-service/app/api/pat.go index 7c6052cffb..d708ba0606 100644 --- a/ee/query-service/app/api/pat.go +++ b/ee/query-service/app/api/pat.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "fmt" "net/http" "time" @@ -55,7 +56,14 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { ctx := context.Background() - user, _ := auth.GetUserFromRequest(r) + user, err := auth.GetUserFromRequest(r) + if err != nil { + RespondError(w, &model.ApiError{ + Typ: model.ErrorUnauthorized, + Err: err, + }, nil) + return + } zap.S().Infof("Get PATs for user: %+v", user.Id) pats, apierr := ah.AppDao().ListPATs(ctx, user.Id) if apierr != nil { @@ -68,6 +76,26 @@ func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) { ctx := context.Background() id := mux.Vars(r)["id"] + user, err := auth.GetUserFromRequest(r) + if err != nil { + RespondError(w, &model.ApiError{ + Typ: model.ErrorUnauthorized, + Err: err, + }, nil) + return + } + pat, apierr := ah.AppDao().GetPAT(ctx, id) + if apierr != nil { + RespondError(w, apierr, nil) + return + } + if pat.UserID != user.Id { + RespondError(w, &model.ApiError{ + Typ: model.ErrorUnauthorized, + Err: fmt.Errorf("unauthorized PAT delete request"), + }, nil) + return + } zap.S().Infof("Delete PAT with id: %+v", id) if apierr := ah.AppDao().DeletePAT(ctx, id); apierr != nil { RespondError(w, apierr, nil) diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index f97f5e0da1..a634057531 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -10,6 +10,7 @@ import ( "net/http" _ "net/http/pprof" // http profiler "os" + "strings" "time" "github.com/gorilla/handlers" @@ -25,7 +26,9 @@ import ( licensepkg "go.signoz.io/signoz/ee/query-service/license" "go.signoz.io/signoz/ee/query-service/usage" + baseapp "go.signoz.io/signoz/pkg/query-service/app" "go.signoz.io/signoz/pkg/query-service/app/dashboards" + baseauth "go.signoz.io/signoz/pkg/query-service/auth" baseconst "go.signoz.io/signoz/pkg/query-service/constants" "go.signoz.io/signoz/pkg/query-service/healthcheck" basealm "go.signoz.io/signoz/pkg/query-service/integrations/alertManager" @@ -199,17 +202,61 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server, }, nil } +func getPATToken(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", nil + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", fmt.Errorf("authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + func (s *Server) createPublicServer(apiHandler *api.APIHandler) (*http.Server, error) { r := mux.NewRouter() + getUserFromPAT := func(r *http.Request) (*model.UserPayload, error) { + patToken, err := getPATToken(r) + if err != nil { + return nil, fmt.Errorf("failed to get PAT token in request headers, err: %v", err) + } + ctx := context.Background() + dao := apiHandler.AppDao() + pat, err := dao.GetPAT(ctx, patToken) + if err != nil { + return nil, fmt.Errorf("failed to fetch PAT token from DB, err %v", err) + } + user, apierr := dao.GetUser(ctx, pat.UserID) + if apierr != nil { + return nil, fmt.Errorf("failed to fetch user for PAT from DB, err: %v", apierr) + } + return user, nil + } + + getUserFromRequest := func(r *http.Request) (*model.UserPayload, error) { + user, err := getUserFromPAT(r) + if err == nil && user != nil { + zap.S().Debugf("Found valid PAT user: %+v", user) + return user, nil + } + if err != nil { + zap.S().Debugf("Error while getting user for PAT: %+v", err) + } + return baseauth.GetUserFromRequest(r) + } + am := baseapp.NewAuthMiddleware(getUserFromRequest) r.Use(setTimeoutMiddleware) r.Use(s.analyticsMiddleware) r.Use(loggingMiddleware) - apiHandler.RegisterRoutes(r) - apiHandler.RegisterMetricsRoutes(r) - apiHandler.RegisterLogsRoutes(r) + apiHandler.RegisterRoutes(r, am) + apiHandler.RegisterMetricsRoutes(r, am) + apiHandler.RegisterLogsRoutes(r, am) c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, diff --git a/ee/query-service/dao/interface.go b/ee/query-service/dao/interface.go index 5bb7c02412..da6963f68c 100644 --- a/ee/query-service/dao/interface.go +++ b/ee/query-service/dao/interface.go @@ -35,7 +35,7 @@ type ModelDao interface { GetDomainByEmail(ctx context.Context, email string) (*model.OrgDomain, basemodel.BaseApiError) CreatePAT(ctx context.Context, p *model.PAT) basemodel.BaseApiError - GetPAT(ctx context.Context, patID string) (*model.PAT, basemodel.BaseApiError) + GetPAT(ctx context.Context, pat string) (*model.PAT, basemodel.BaseApiError) ListPATs(ctx context.Context, userID string) ([]model.PAT, basemodel.BaseApiError) DeletePAT(ctx context.Context, id string) basemodel.BaseApiError } diff --git a/ee/query-service/dao/sqlite/pat.go b/ee/query-service/dao/sqlite/pat.go index fc3c406ecb..7b3569db66 100644 --- a/ee/query-service/dao/sqlite/pat.go +++ b/ee/query-service/dao/sqlite/pat.go @@ -43,18 +43,17 @@ func (m *modelDao) DeletePAT(ctx context.Context, id string) basemodel.BaseApiEr return nil } -func (m *modelDao) GetPAT(ctx context.Context, patID string) (*model.PAT, basemodel.BaseApiError) { +func (m *modelDao) GetPAT(ctx context.Context, token string) (*model.PAT, basemodel.BaseApiError) { pats := []model.PAT{} - if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE id=?;`, patID); err != nil { - zap.S().Errorf("Failed to fetch PAT with ID: %s, err: %v", patID, zap.Error(err)) + if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE token=?;`, token); err != nil { return nil, model.InternalError(fmt.Errorf("failed to fetch PAT")) } if len(pats) != 1 { return nil, &model.ApiError{ Typ: model.ErrorInternal, - Err: fmt.Errorf("found zero or multiple PATS with same ID"), + Err: fmt.Errorf("found zero or multiple PATs with same token, %s", token), } } diff --git a/pkg/query-service/app/auth.go b/pkg/query-service/app/auth.go new file mode 100644 index 0000000000..ef6541a4f6 --- /dev/null +++ b/pkg/query-service/app/auth.go @@ -0,0 +1,116 @@ +package app + +import ( + "errors" + "net/http" + + "github.com/gorilla/mux" + "go.signoz.io/signoz/pkg/query-service/auth" + "go.signoz.io/signoz/pkg/query-service/model" +) + +type AuthMiddleware struct { + GetUserFromRequest func(r *http.Request) (*model.UserPayload, error) +} + +func NewAuthMiddleware(f func(r *http.Request) (*model.UserPayload, error)) *AuthMiddleware { + return &AuthMiddleware{ + GetUserFromRequest: f, + } +} + +// func (am *AuthMiddleware) GetUserFromRequest(r *http.Request) (*model.UserPayload, error) { +// return auth.GetUserFromRequest(r) +// } + +func (am *AuthMiddleware) OpenAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + f(w, r) + } +} + +func (am *AuthMiddleware) ViewAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user, err := am.GetUserFromRequest(r) + if err != nil { + RespondError(w, &model.ApiError{ + Typ: model.ErrorUnauthorized, + Err: err, + }, nil) + return + } + + if !(auth.IsViewer(user) || auth.IsEditor(user) || auth.IsAdmin(user)) { + RespondError(w, &model.ApiError{ + Typ: model.ErrorForbidden, + Err: errors.New("API is accessible to viewers/editors/admins."), + }, nil) + return + } + f(w, r) + } +} + +func (am *AuthMiddleware) EditAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user, err := am.GetUserFromRequest(r) + if err != nil { + RespondError(w, &model.ApiError{ + Typ: model.ErrorUnauthorized, + Err: err, + }, nil) + return + } + if !(auth.IsEditor(user) || auth.IsAdmin(user)) { + RespondError(w, &model.ApiError{ + Typ: model.ErrorForbidden, + Err: errors.New("API is accessible to editors/admins."), + }, nil) + return + } + f(w, r) + } +} + +func (am *AuthMiddleware) SelfAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user, err := am.GetUserFromRequest(r) + if err != nil { + RespondError(w, &model.ApiError{ + Typ: model.ErrorUnauthorized, + Err: err, + }, nil) + return + } + id := mux.Vars(r)["id"] + if !(auth.IsSelfAccessRequest(user, id) || auth.IsAdmin(user)) { + RespondError(w, &model.ApiError{ + Typ: model.ErrorForbidden, + Err: errors.New("API is accessible for self access or to the admins."), + }, nil) + return + } + f(w, r) + } +} + +func (am *AuthMiddleware) AdminAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user, err := am.GetUserFromRequest(r) + if err != nil { + RespondError(w, &model.ApiError{ + Typ: model.ErrorUnauthorized, + Err: err, + }, nil) + return + } + if !auth.IsAdmin(user) { + RespondError(w, &model.ApiError{ + Typ: model.ErrorForbidden, + Err: errors.New("API is accessible to admins only"), + }, nil) + return + } + f(w, r) + } +} diff --git a/pkg/query-service/app/http_handler.go b/pkg/query-service/app/http_handler.go index 06de321612..eeb324f149 100644 --- a/pkg/query-service/app/http_handler.go +++ b/pkg/query-service/app/http_handler.go @@ -229,202 +229,109 @@ func writeHttpResponse(w http.ResponseWriter, data interface{}) { } } -func (aH *APIHandler) RegisterMetricsRoutes(router *mux.Router) { +func (aH *APIHandler) RegisterMetricsRoutes(router *mux.Router, am *AuthMiddleware) { subRouter := router.PathPrefix("/api/v2/metrics").Subrouter() - subRouter.HandleFunc("/query_range", ViewAccess(aH.QueryRangeMetricsV2)).Methods(http.MethodPost) - subRouter.HandleFunc("/autocomplete/list", ViewAccess(aH.metricAutocompleteMetricName)).Methods(http.MethodGet) - subRouter.HandleFunc("/autocomplete/tagKey", ViewAccess(aH.metricAutocompleteTagKey)).Methods(http.MethodGet) - subRouter.HandleFunc("/autocomplete/tagValue", ViewAccess(aH.metricAutocompleteTagValue)).Methods(http.MethodGet) + subRouter.HandleFunc("/query_range", am.ViewAccess(aH.QueryRangeMetricsV2)).Methods(http.MethodPost) + subRouter.HandleFunc("/autocomplete/list", am.ViewAccess(aH.metricAutocompleteMetricName)).Methods(http.MethodGet) + subRouter.HandleFunc("/autocomplete/tagKey", am.ViewAccess(aH.metricAutocompleteTagKey)).Methods(http.MethodGet) + subRouter.HandleFunc("/autocomplete/tagValue", am.ViewAccess(aH.metricAutocompleteTagValue)).Methods(http.MethodGet) } func (aH *APIHandler) Respond(w http.ResponseWriter, data interface{}) { writeHttpResponse(w, data) } -func OpenAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - f(w, r) - } -} - -func ViewAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, err := auth.GetUserFromRequest(r) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - - if !(auth.IsViewer(user) || auth.IsEditor(user) || auth.IsAdmin(user)) { - RespondError(w, &model.ApiError{ - Typ: model.ErrorForbidden, - Err: errors.New("API is accessible to viewers/editors/admins."), - }, nil) - return - } - f(w, r) - } -} - -func EditAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, err := auth.GetUserFromRequest(r) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - if !(auth.IsEditor(user) || auth.IsAdmin(user)) { - RespondError(w, &model.ApiError{ - Typ: model.ErrorForbidden, - Err: errors.New("API is accessible to editors/admins."), - }, nil) - return - } - f(w, r) - } -} - -func SelfAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, err := auth.GetUserFromRequest(r) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - id := mux.Vars(r)["id"] - if !(auth.IsSelfAccessRequest(user, id) || auth.IsAdmin(user)) { - RespondError(w, &model.ApiError{ - Typ: model.ErrorForbidden, - Err: errors.New("API is accessible for self access or to the admins."), - }, nil) - return - } - f(w, r) - } -} - -func AdminAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user, err := auth.GetUserFromRequest(r) - if err != nil { - RespondError(w, &model.ApiError{ - Typ: model.ErrorUnauthorized, - Err: err, - }, nil) - return - } - if !auth.IsAdmin(user) { - RespondError(w, &model.ApiError{ - Typ: model.ErrorForbidden, - Err: errors.New("API is accessible to admins only"), - }, nil) - return - } - f(w, r) - } -} - // RegisterPrivateRoutes registers routes for this handler on the given router func (aH *APIHandler) RegisterPrivateRoutes(router *mux.Router) { router.HandleFunc("/api/v1/channels", aH.listChannels).Methods(http.MethodGet) } // RegisterRoutes registers routes for this handler on the given router -func (aH *APIHandler) RegisterRoutes(router *mux.Router) { - router.HandleFunc("/api/v1/query_range", ViewAccess(aH.queryRangeMetrics)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/query", ViewAccess(aH.queryMetrics)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/channels", ViewAccess(aH.listChannels)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/channels/{id}", ViewAccess(aH.getChannel)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/channels/{id}", AdminAccess(aH.editChannel)).Methods(http.MethodPut) - router.HandleFunc("/api/v1/channels/{id}", AdminAccess(aH.deleteChannel)).Methods(http.MethodDelete) - router.HandleFunc("/api/v1/channels", EditAccess(aH.createChannel)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/testChannel", EditAccess(aH.testChannel)).Methods(http.MethodPost) +func (aH *APIHandler) RegisterRoutes(router *mux.Router, am *AuthMiddleware) { + router.HandleFunc("/api/v1/query_range", am.ViewAccess(aH.queryRangeMetrics)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/query", am.ViewAccess(aH.queryMetrics)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/channels", am.ViewAccess(aH.listChannels)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/channels/{id}", am.ViewAccess(aH.getChannel)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/channels/{id}", am.AdminAccess(aH.editChannel)).Methods(http.MethodPut) + router.HandleFunc("/api/v1/channels/{id}", am.AdminAccess(aH.deleteChannel)).Methods(http.MethodDelete) + router.HandleFunc("/api/v1/channels", am.EditAccess(aH.createChannel)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/testChannel", am.EditAccess(aH.testChannel)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/rules", ViewAccess(aH.listRules)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/rules/{id}", ViewAccess(aH.getRule)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/rules", EditAccess(aH.createRule)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/rules/{id}", EditAccess(aH.editRule)).Methods(http.MethodPut) - router.HandleFunc("/api/v1/rules/{id}", EditAccess(aH.deleteRule)).Methods(http.MethodDelete) - router.HandleFunc("/api/v1/rules/{id}", EditAccess(aH.patchRule)).Methods(http.MethodPatch) - router.HandleFunc("/api/v1/testRule", EditAccess(aH.testRule)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/rules", am.ViewAccess(aH.listRules)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/rules/{id}", am.ViewAccess(aH.getRule)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/rules", am.EditAccess(aH.createRule)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/rules/{id}", am.EditAccess(aH.editRule)).Methods(http.MethodPut) + router.HandleFunc("/api/v1/rules/{id}", am.EditAccess(aH.deleteRule)).Methods(http.MethodDelete) + router.HandleFunc("/api/v1/rules/{id}", am.EditAccess(aH.patchRule)).Methods(http.MethodPatch) + router.HandleFunc("/api/v1/testRule", am.EditAccess(aH.testRule)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/dashboards", ViewAccess(aH.getDashboards)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/dashboards", EditAccess(aH.createDashboards)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/dashboards/grafana", EditAccess(aH.createDashboardsTransform)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/dashboards/{uuid}", ViewAccess(aH.getDashboard)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/dashboards/{uuid}", EditAccess(aH.updateDashboard)).Methods(http.MethodPut) - router.HandleFunc("/api/v1/dashboards/{uuid}", EditAccess(aH.deleteDashboard)).Methods(http.MethodDelete) - router.HandleFunc("/api/v1/variables/query", ViewAccess(aH.queryDashboardVars)).Methods(http.MethodGet) - router.HandleFunc("/api/v2/variables/query", ViewAccess(aH.queryDashboardVarsV2)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/dashboards", am.ViewAccess(aH.getDashboards)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/dashboards", am.EditAccess(aH.createDashboards)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/dashboards/grafana", am.EditAccess(aH.createDashboardsTransform)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/dashboards/{uuid}", am.ViewAccess(aH.getDashboard)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/dashboards/{uuid}", am.EditAccess(aH.updateDashboard)).Methods(http.MethodPut) + router.HandleFunc("/api/v1/dashboards/{uuid}", am.EditAccess(aH.deleteDashboard)).Methods(http.MethodDelete) + router.HandleFunc("/api/v1/variables/query", am.ViewAccess(aH.queryDashboardVars)).Methods(http.MethodGet) + router.HandleFunc("/api/v2/variables/query", am.ViewAccess(aH.queryDashboardVarsV2)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/feedback", OpenAccess(aH.submitFeedback)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/feedback", am.OpenAccess(aH.submitFeedback)).Methods(http.MethodPost) // router.HandleFunc("/api/v1/get_percentiles", aH.getApplicationPercentiles).Methods(http.MethodGet) - router.HandleFunc("/api/v1/services", ViewAccess(aH.getServices)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/services/list", ViewAccess(aH.getServicesList)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/service/overview", ViewAccess(aH.getServiceOverview)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/service/top_operations", ViewAccess(aH.getTopOperations)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/service/top_level_operations", ViewAccess(aH.getServicesTopLevelOps)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/traces/{traceId}", ViewAccess(aH.SearchTraces)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/usage", ViewAccess(aH.getUsage)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/dependency_graph", ViewAccess(aH.dependencyGraph)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/settings/ttl", AdminAccess(aH.setTTL)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/settings/ttl", ViewAccess(aH.getTTL)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/services", am.ViewAccess(aH.getServices)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/services/list", am.ViewAccess(aH.getServicesList)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/service/overview", am.ViewAccess(aH.getServiceOverview)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/service/top_operations", am.ViewAccess(aH.getTopOperations)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/service/top_level_operations", am.ViewAccess(aH.getServicesTopLevelOps)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/traces/{traceId}", am.ViewAccess(aH.SearchTraces)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/usage", am.ViewAccess(aH.getUsage)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/dependency_graph", am.ViewAccess(aH.dependencyGraph)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/settings/ttl", am.AdminAccess(aH.setTTL)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/settings/ttl", am.ViewAccess(aH.getTTL)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/version", OpenAccess(aH.getVersion)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/featureFlags", OpenAccess(aH.getFeatureFlags)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/configs", OpenAccess(aH.getConfigs)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/health", OpenAccess(aH.getHealth)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/version", am.OpenAccess(aH.getVersion)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/featureFlags", am.OpenAccess(aH.getFeatureFlags)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/configs", am.OpenAccess(aH.getConfigs)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/getSpanFilters", ViewAccess(aH.getSpanFilters)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/getTagFilters", ViewAccess(aH.getTagFilters)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/getFilteredSpans", ViewAccess(aH.getFilteredSpans)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/getFilteredSpans/aggregates", ViewAccess(aH.getFilteredSpanAggregates)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/getTagValues", ViewAccess(aH.getTagValues)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/getSpanFilters", am.ViewAccess(aH.getSpanFilters)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/getTagFilters", am.ViewAccess(aH.getTagFilters)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/getFilteredSpans", am.ViewAccess(aH.getFilteredSpans)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/getFilteredSpans/aggregates", am.ViewAccess(aH.getFilteredSpanAggregates)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/getTagValues", am.ViewAccess(aH.getTagValues)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/listErrors", ViewAccess(aH.listErrors)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/countErrors", ViewAccess(aH.countErrors)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/errorFromErrorID", ViewAccess(aH.getErrorFromErrorID)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/errorFromGroupID", ViewAccess(aH.getErrorFromGroupID)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/nextPrevErrorIDs", ViewAccess(aH.getNextPrevErrorIDs)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/listErrors", am.ViewAccess(aH.listErrors)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/countErrors", am.ViewAccess(aH.countErrors)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/errorFromErrorID", am.ViewAccess(aH.getErrorFromErrorID)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/errorFromGroupID", am.ViewAccess(aH.getErrorFromGroupID)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/nextPrevErrorIDs", am.ViewAccess(aH.getNextPrevErrorIDs)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/disks", ViewAccess(aH.getDisks)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/disks", am.ViewAccess(aH.getDisks)).Methods(http.MethodGet) // === Authentication APIs === - router.HandleFunc("/api/v1/invite", AdminAccess(aH.inviteUser)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/invite/{token}", OpenAccess(aH.getInvite)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/invite/{email}", AdminAccess(aH.revokeInvite)).Methods(http.MethodDelete) - router.HandleFunc("/api/v1/invite", AdminAccess(aH.listPendingInvites)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/invite", am.AdminAccess(aH.inviteUser)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/invite/{token}", am.OpenAccess(aH.getInvite)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/invite/{email}", am.AdminAccess(aH.revokeInvite)).Methods(http.MethodDelete) + router.HandleFunc("/api/v1/invite", am.AdminAccess(aH.listPendingInvites)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/register", OpenAccess(aH.registerUser)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/login", OpenAccess(aH.loginUser)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/register", am.OpenAccess(aH.registerUser)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/login", am.OpenAccess(aH.loginUser)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/user", AdminAccess(aH.listUsers)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/user/{id}", SelfAccess(aH.getUser)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/user/{id}", SelfAccess(aH.editUser)).Methods(http.MethodPut) - router.HandleFunc("/api/v1/user/{id}", AdminAccess(aH.deleteUser)).Methods(http.MethodDelete) + router.HandleFunc("/api/v1/user", am.AdminAccess(aH.listUsers)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/user/{id}", am.SelfAccess(aH.getUser)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/user/{id}", am.SelfAccess(aH.editUser)).Methods(http.MethodPut) + router.HandleFunc("/api/v1/user/{id}", am.AdminAccess(aH.deleteUser)).Methods(http.MethodDelete) - router.HandleFunc("/api/v1/user/{id}/flags", SelfAccess(aH.patchUserFlag)).Methods(http.MethodPatch) + router.HandleFunc("/api/v1/user/{id}/flags", am.SelfAccess(aH.patchUserFlag)).Methods(http.MethodPatch) - router.HandleFunc("/api/v1/rbac/role/{id}", SelfAccess(aH.getRole)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/rbac/role/{id}", AdminAccess(aH.editRole)).Methods(http.MethodPut) + router.HandleFunc("/api/v1/rbac/role/{id}", am.SelfAccess(aH.getRole)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/rbac/role/{id}", am.AdminAccess(aH.editRole)).Methods(http.MethodPut) - router.HandleFunc("/api/v1/org", AdminAccess(aH.getOrgs)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/org/{id}", AdminAccess(aH.getOrg)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/org/{id}", AdminAccess(aH.editOrg)).Methods(http.MethodPut) - router.HandleFunc("/api/v1/orgUsers/{id}", AdminAccess(aH.getOrgUsers)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/org", am.AdminAccess(aH.getOrgs)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/org/{id}", am.AdminAccess(aH.getOrg)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/org/{id}", am.AdminAccess(aH.editOrg)).Methods(http.MethodPut) + router.HandleFunc("/api/v1/orgUsers/{id}", am.AdminAccess(aH.getOrgUsers)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/getResetPasswordToken/{id}", AdminAccess(aH.getResetPasswordToken)).Methods(http.MethodGet) - router.HandleFunc("/api/v1/resetPassword", OpenAccess(aH.resetPassword)).Methods(http.MethodPost) - router.HandleFunc("/api/v1/changePassword/{id}", SelfAccess(aH.changePassword)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/getResetPasswordToken/{id}", am.AdminAccess(aH.getResetPasswordToken)).Methods(http.MethodGet) + router.HandleFunc("/api/v1/resetPassword", am.OpenAccess(aH.resetPassword)).Methods(http.MethodPost) + router.HandleFunc("/api/v1/changePassword/{id}", am.SelfAccess(aH.changePassword)).Methods(http.MethodPost) } func Intersection(a, b []int) (c []int) { @@ -2221,13 +2128,13 @@ func (aH *APIHandler) WriteJSON(w http.ResponseWriter, r *http.Request, response } // logs -func (aH *APIHandler) RegisterLogsRoutes(router *mux.Router) { +func (aH *APIHandler) RegisterLogsRoutes(router *mux.Router, am *AuthMiddleware) { subRouter := router.PathPrefix("/api/v1/logs").Subrouter() - subRouter.HandleFunc("", ViewAccess(aH.getLogs)).Methods(http.MethodGet) - subRouter.HandleFunc("/tail", ViewAccess(aH.tailLogs)).Methods(http.MethodGet) - subRouter.HandleFunc("/fields", ViewAccess(aH.logFields)).Methods(http.MethodGet) - subRouter.HandleFunc("/fields", EditAccess(aH.logFieldUpdate)).Methods(http.MethodPost) - subRouter.HandleFunc("/aggregate", ViewAccess(aH.logAggregate)).Methods(http.MethodGet) + subRouter.HandleFunc("", am.ViewAccess(aH.getLogs)).Methods(http.MethodGet) + subRouter.HandleFunc("/tail", am.ViewAccess(aH.tailLogs)).Methods(http.MethodGet) + subRouter.HandleFunc("/fields", am.ViewAccess(aH.logFields)).Methods(http.MethodGet) + subRouter.HandleFunc("/fields", am.EditAccess(aH.logFieldUpdate)).Methods(http.MethodPost) + subRouter.HandleFunc("/aggregate", am.ViewAccess(aH.logAggregate)).Methods(http.MethodGet) } func (aH *APIHandler) logFields(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/query-service/app/server.go b/pkg/query-service/app/server.go index c9505b9391..c99046fac4 100644 --- a/pkg/query-service/app/server.go +++ b/pkg/query-service/app/server.go @@ -20,6 +20,7 @@ import ( "github.com/soheilhy/cmux" "go.signoz.io/signoz/pkg/query-service/app/clickhouseReader" "go.signoz.io/signoz/pkg/query-service/app/dashboards" + "go.signoz.io/signoz/pkg/query-service/auth" "go.signoz.io/signoz/pkg/query-service/constants" "go.signoz.io/signoz/pkg/query-service/dao" "go.signoz.io/signoz/pkg/query-service/featureManager" @@ -176,9 +177,11 @@ func (s *Server) createPublicServer(api *APIHandler) (*http.Server, error) { r.Use(s.analyticsMiddleware) r.Use(loggingMiddleware) - api.RegisterRoutes(r) - api.RegisterMetricsRoutes(r) - api.RegisterLogsRoutes(r) + am := NewAuthMiddleware(auth.GetUserFromRequest) + + api.RegisterRoutes(r, am) + api.RegisterMetricsRoutes(r, am) + api.RegisterLogsRoutes(r, am) c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, diff --git a/pkg/query-service/auth/rbac.go b/pkg/query-service/auth/rbac.go index 44f65576ed..38ae283bb2 100644 --- a/pkg/query-service/auth/rbac.go +++ b/pkg/query-service/auth/rbac.go @@ -3,6 +3,7 @@ package auth import ( "context" "net/http" + "strings" "github.com/pkg/errors" "go.signoz.io/signoz/pkg/query-service/constants" @@ -48,6 +49,19 @@ func InitAuthCache(ctx context.Context) error { return nil } +func GetAuthorizationToken(r *http.Request) string { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "" + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "" + } + return authHeaderParts[1] +} + func GetUserFromRequest(r *http.Request) (*model.UserPayload, error) { accessJwt, err := ExtractJwtFromRequest(r) if err != nil { From b0f62daa2427a4485b0913a54d6f34e66c641ae3 Mon Sep 17 00:00:00 2001 From: Ahsan Barkati Date: Wed, 15 Feb 2023 23:59:03 +0530 Subject: [PATCH 4/6] Cleanup rbac.go --- pkg/query-service/auth/rbac.go | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/pkg/query-service/auth/rbac.go b/pkg/query-service/auth/rbac.go index 38ae283bb2..44f65576ed 100644 --- a/pkg/query-service/auth/rbac.go +++ b/pkg/query-service/auth/rbac.go @@ -3,7 +3,6 @@ package auth import ( "context" "net/http" - "strings" "github.com/pkg/errors" "go.signoz.io/signoz/pkg/query-service/constants" @@ -49,19 +48,6 @@ func InitAuthCache(ctx context.Context) error { return nil } -func GetAuthorizationToken(r *http.Request) string { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "" - } - - authHeaderParts := strings.Fields(authHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "" - } - return authHeaderParts[1] -} - func GetUserFromRequest(r *http.Request) (*model.UserPayload, error) { accessJwt, err := ExtractJwtFromRequest(r) if err != nil { From df7f276f0306e652c953ad14ac6fedbf04424b08 Mon Sep 17 00:00:00 2001 From: Ahsan Barkati Date: Tue, 21 Feb 2023 22:37:02 +0530 Subject: [PATCH 5/6] Change header name --- ee/query-service/app/server.go | 8 ++++---- pkg/query-service/app/auth.go | 4 ---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index a634057531..42a8f2496e 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -203,14 +203,14 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server, } func getPATToken(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { + patHeader := r.Header.Get("SIGNOZ-API-KEY") + if patHeader == "" { return "", nil } - authHeaderParts := strings.Fields(authHeader) + authHeaderParts := strings.Fields(patHeader) if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", fmt.Errorf("authorization header format must be Bearer {token}") + return "", fmt.Errorf("PAT authorization header format must be bearer {token}") } return authHeaderParts[1], nil diff --git a/pkg/query-service/app/auth.go b/pkg/query-service/app/auth.go index ef6541a4f6..dccf6dd8dd 100644 --- a/pkg/query-service/app/auth.go +++ b/pkg/query-service/app/auth.go @@ -19,10 +19,6 @@ func NewAuthMiddleware(f func(r *http.Request) (*model.UserPayload, error)) *Aut } } -// func (am *AuthMiddleware) GetUserFromRequest(r *http.Request) (*model.UserPayload, error) { -// return auth.GetUserFromRequest(r) -// } - func (am *AuthMiddleware) OpenAccess(f func(http.ResponseWriter, *http.Request)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { f(w, r) From eb2fe200251e4222e69f2baafeb9be05db1da757 Mon Sep 17 00:00:00 2001 From: Ahsan Barkati Date: Fri, 24 Feb 2023 01:19:01 +0530 Subject: [PATCH 6/6] Address review comments --- ee/query-service/app/api/pat.go | 12 +++--- ee/query-service/app/server.go | 56 +++++++------------------ ee/query-service/dao/interface.go | 2 + ee/query-service/dao/sqlite/modelDao.go | 2 +- ee/query-service/dao/sqlite/pat.go | 45 ++++++++++++++++++++ ee/query-service/model/pat.go | 2 +- 6 files changed, 71 insertions(+), 48 deletions(-) diff --git a/ee/query-service/app/api/pat.go b/ee/query-service/app/api/pat.go index d708ba0606..619c875c8f 100644 --- a/ee/query-service/app/api/pat.go +++ b/ee/query-service/app/api/pat.go @@ -41,11 +41,13 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) { return } + // All the PATs are associated with the user creating the PAT. Hence, the permissions + // associated with the PAT is also equivalent to that of the user. req.UserID = user.Id req.CreatedAt = time.Now().Unix() req.Token = generatePATToken() - zap.S().Infof("Got PAT request: %+v", req) + zap.S().Debugf("Got PAT request: %+v", req) if apierr := ah.AppDao().CreatePAT(ctx, &req); apierr != nil { RespondError(w, apierr, nil) return @@ -70,7 +72,7 @@ func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) { RespondError(w, apierr, nil) return } - ah.WriteJSON(w, r, pats) + ah.Respond(w, pats) } func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) { @@ -84,7 +86,7 @@ func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) { }, nil) return } - pat, apierr := ah.AppDao().GetPAT(ctx, id) + pat, apierr := ah.AppDao().GetPATByID(ctx, id) if apierr != nil { RespondError(w, apierr, nil) return @@ -96,10 +98,10 @@ func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) { }, nil) return } - zap.S().Infof("Delete PAT with id: %+v", id) + zap.S().Debugf("Delete PAT with id: %+v", id) if apierr := ah.AppDao().DeletePAT(ctx, id); apierr != nil { RespondError(w, apierr, nil) return } - ah.WriteJSON(w, r, map[string]string{"data": "pat deleted successfully"}) + ah.Respond(w, map[string]string{"data": "pat deleted successfully"}) } diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index 42a8f2496e..7a0d794a88 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -10,7 +10,6 @@ import ( "net/http" _ "net/http/pprof" // http profiler "os" - "strings" "time" "github.com/gorilla/handlers" @@ -191,7 +190,7 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server, // ip here for alert manager AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "DELETE", "POST", "PUT", "PATCH"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "SIGNOZ-API-KEY"}, }) handler := c.Handler(r) @@ -202,50 +201,25 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server, }, nil } -func getPATToken(r *http.Request) (string, error) { - patHeader := r.Header.Get("SIGNOZ-API-KEY") - if patHeader == "" { - return "", nil - } - - authHeaderParts := strings.Fields(patHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", fmt.Errorf("PAT authorization header format must be bearer {token}") - } - - return authHeaderParts[1], nil -} - func (s *Server) createPublicServer(apiHandler *api.APIHandler) (*http.Server, error) { r := mux.NewRouter() - getUserFromPAT := func(r *http.Request) (*model.UserPayload, error) { - patToken, err := getPATToken(r) - if err != nil { - return nil, fmt.Errorf("failed to get PAT token in request headers, err: %v", err) - } - ctx := context.Background() - dao := apiHandler.AppDao() - pat, err := dao.GetPAT(ctx, patToken) - if err != nil { - return nil, fmt.Errorf("failed to fetch PAT token from DB, err %v", err) - } - user, apierr := dao.GetUser(ctx, pat.UserID) - if apierr != nil { - return nil, fmt.Errorf("failed to fetch user for PAT from DB, err: %v", apierr) - } - return user, nil - } - getUserFromRequest := func(r *http.Request) (*model.UserPayload, error) { - user, err := getUserFromPAT(r) - if err == nil && user != nil { - zap.S().Debugf("Found valid PAT user: %+v", user) - return user, nil - } - if err != nil { - zap.S().Debugf("Error while getting user for PAT: %+v", err) + patToken := r.Header.Get("SIGNOZ-API-KEY") + if len(patToken) > 0 { + zap.S().Debugf("Received a non-zero length PAT token") + ctx := context.Background() + dao := apiHandler.AppDao() + + user, err := dao.GetUserByPAT(ctx, patToken) + if err == nil && user != nil { + zap.S().Debugf("Found valid PAT user: %+v", user) + return user, nil + } + if err != nil { + zap.S().Debugf("Error while getting user for PAT: %+v", err) + } } return baseauth.GetUserFromRequest(r) } diff --git a/ee/query-service/dao/interface.go b/ee/query-service/dao/interface.go index da6963f68c..2303bb72d4 100644 --- a/ee/query-service/dao/interface.go +++ b/ee/query-service/dao/interface.go @@ -36,6 +36,8 @@ type ModelDao interface { CreatePAT(ctx context.Context, p *model.PAT) basemodel.BaseApiError GetPAT(ctx context.Context, pat string) (*model.PAT, basemodel.BaseApiError) + GetPATByID(ctx context.Context, id string) (*model.PAT, basemodel.BaseApiError) + GetUserByPAT(ctx context.Context, token string) (*basemodel.UserPayload, basemodel.BaseApiError) ListPATs(ctx context.Context, userID string) ([]model.PAT, basemodel.BaseApiError) DeletePAT(ctx context.Context, id string) basemodel.BaseApiError } diff --git a/ee/query-service/dao/sqlite/modelDao.go b/ee/query-service/dao/sqlite/modelDao.go index 9b1d74c034..3c195ea9bf 100644 --- a/ee/query-service/dao/sqlite/modelDao.go +++ b/ee/query-service/dao/sqlite/modelDao.go @@ -52,7 +52,7 @@ func InitDB(dataSourceName string) (*modelDao, error) { CREATE TABLE IF NOT EXISTS personal_access_tokens ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id TEXT NOT NULL, - token TEXT NOT NULL, + token TEXT NOT NULL UNIQUE, name TEXT NOT NULL, created_at INTEGER NOT NULL, expires_at INTEGER NOT NULL, diff --git a/ee/query-service/dao/sqlite/pat.go b/ee/query-service/dao/sqlite/pat.go index 7b3569db66..cc4de546c5 100644 --- a/ee/query-service/dao/sqlite/pat.go +++ b/ee/query-service/dao/sqlite/pat.go @@ -59,3 +59,48 @@ func (m *modelDao) GetPAT(ctx context.Context, token string) (*model.PAT, basemo return &pats[0], nil } + +func (m *modelDao) GetPATByID(ctx context.Context, id string) (*model.PAT, basemodel.BaseApiError) { + pats := []model.PAT{} + + if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE id=?;`, id); err != nil { + return nil, model.InternalError(fmt.Errorf("failed to fetch PAT")) + } + + if len(pats) != 1 { + return nil, &model.ApiError{ + Typ: model.ErrorInternal, + Err: fmt.Errorf("found zero or multiple PATs with same token"), + } + } + + return &pats[0], nil +} + +func (m *modelDao) GetUserByPAT(ctx context.Context, token string) (*basemodel.UserPayload, basemodel.BaseApiError) { + users := []basemodel.UserPayload{} + + query := `SELECT + u.id, + u.name, + u.email, + u.password, + u.created_at, + u.profile_picture_url, + u.org_id, + u.group_id + FROM users u, personal_access_tokens p + WHERE u.id = p.user_id and p.token=?;` + + if err := m.DB().Select(&users, query, token); err != nil { + return nil, model.InternalError(fmt.Errorf("failed to fetch user from PAT, err: %v", err)) + } + + if len(users) != 1 { + return nil, &model.ApiError{ + Typ: model.ErrorInternal, + Err: fmt.Errorf("found zero or multiple users with same PAT token"), + } + } + return &users[0], nil +} diff --git a/ee/query-service/model/pat.go b/ee/query-service/model/pat.go index f320d0be7c..c22282060b 100644 --- a/ee/query-service/model/pat.go +++ b/ee/query-service/model/pat.go @@ -6,5 +6,5 @@ type PAT struct { Token string `json:"token" db:"token"` Name string `json:"name" db:"name"` CreatedAt int64 `json:"createdAt" db:"created_at"` - ExpiresAt int64 `json:"expiresAt" db:"expires_at"` + ExpiresAt int64 `json:"expiresAt" db:"expires_at"` // unused as of now }