Address review comments

This commit is contained in:
Ahsan Barkati 2023-02-24 01:19:01 +05:30
parent df7f276f03
commit eb2fe20025
6 changed files with 71 additions and 48 deletions

View File

@ -41,11 +41,13 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
return return
} }
// All the PATs are associated with the user creating the PAT. Hence, the permissions
// associated with the PAT is also equivalent to that of the user.
req.UserID = user.Id req.UserID = user.Id
req.CreatedAt = time.Now().Unix() req.CreatedAt = time.Now().Unix()
req.Token = generatePATToken() req.Token = generatePATToken()
zap.S().Infof("Got PAT request: %+v", req) zap.S().Debugf("Got PAT request: %+v", req)
if apierr := ah.AppDao().CreatePAT(ctx, &req); apierr != nil { if apierr := ah.AppDao().CreatePAT(ctx, &req); apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
@ -70,7 +72,7 @@ func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }
ah.WriteJSON(w, r, pats) ah.Respond(w, pats)
} }
func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) { func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) {
@ -84,7 +86,7 @@ func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) {
}, nil) }, nil)
return return
} }
pat, apierr := ah.AppDao().GetPAT(ctx, id) pat, apierr := ah.AppDao().GetPATByID(ctx, id)
if apierr != nil { if apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
@ -96,10 +98,10 @@ func (ah *APIHandler) deletePAT(w http.ResponseWriter, r *http.Request) {
}, nil) }, nil)
return return
} }
zap.S().Infof("Delete PAT with id: %+v", id) zap.S().Debugf("Delete PAT with id: %+v", id)
if apierr := ah.AppDao().DeletePAT(ctx, id); apierr != nil { if apierr := ah.AppDao().DeletePAT(ctx, id); apierr != nil {
RespondError(w, apierr, nil) RespondError(w, apierr, nil)
return return
} }
ah.WriteJSON(w, r, map[string]string{"data": "pat deleted successfully"}) ah.Respond(w, map[string]string{"data": "pat deleted successfully"})
} }

View File

@ -10,7 +10,6 @@ import (
"net/http" "net/http"
_ "net/http/pprof" // http profiler _ "net/http/pprof" // http profiler
"os" "os"
"strings"
"time" "time"
"github.com/gorilla/handlers" "github.com/gorilla/handlers"
@ -191,7 +190,7 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server,
// ip here for alert manager // ip here for alert manager
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "DELETE", "POST", "PUT", "PATCH"}, AllowedMethods: []string{"GET", "DELETE", "POST", "PUT", "PATCH"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "SIGNOZ-API-KEY"},
}) })
handler := c.Handler(r) handler := c.Handler(r)
@ -202,44 +201,18 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server,
}, nil }, nil
} }
func getPATToken(r *http.Request) (string, error) {
patHeader := r.Header.Get("SIGNOZ-API-KEY")
if patHeader == "" {
return "", nil
}
authHeaderParts := strings.Fields(patHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", fmt.Errorf("PAT authorization header format must be bearer {token}")
}
return authHeaderParts[1], nil
}
func (s *Server) createPublicServer(apiHandler *api.APIHandler) (*http.Server, error) { func (s *Server) createPublicServer(apiHandler *api.APIHandler) (*http.Server, error) {
r := mux.NewRouter() r := mux.NewRouter()
getUserFromPAT := func(r *http.Request) (*model.UserPayload, error) { getUserFromRequest := func(r *http.Request) (*model.UserPayload, error) {
patToken, err := getPATToken(r) patToken := r.Header.Get("SIGNOZ-API-KEY")
if err != nil { if len(patToken) > 0 {
return nil, fmt.Errorf("failed to get PAT token in request headers, err: %v", err) zap.S().Debugf("Received a non-zero length PAT token")
}
ctx := context.Background() ctx := context.Background()
dao := apiHandler.AppDao() dao := apiHandler.AppDao()
pat, err := dao.GetPAT(ctx, patToken)
if err != nil {
return nil, fmt.Errorf("failed to fetch PAT token from DB, err %v", err)
}
user, apierr := dao.GetUser(ctx, pat.UserID)
if apierr != nil {
return nil, fmt.Errorf("failed to fetch user for PAT from DB, err: %v", apierr)
}
return user, nil
}
getUserFromRequest := func(r *http.Request) (*model.UserPayload, error) { user, err := dao.GetUserByPAT(ctx, patToken)
user, err := getUserFromPAT(r)
if err == nil && user != nil { if err == nil && user != nil {
zap.S().Debugf("Found valid PAT user: %+v", user) zap.S().Debugf("Found valid PAT user: %+v", user)
return user, nil return user, nil
@ -247,6 +220,7 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler) (*http.Server, e
if err != nil { if err != nil {
zap.S().Debugf("Error while getting user for PAT: %+v", err) zap.S().Debugf("Error while getting user for PAT: %+v", err)
} }
}
return baseauth.GetUserFromRequest(r) return baseauth.GetUserFromRequest(r)
} }
am := baseapp.NewAuthMiddleware(getUserFromRequest) am := baseapp.NewAuthMiddleware(getUserFromRequest)

View File

@ -36,6 +36,8 @@ type ModelDao interface {
CreatePAT(ctx context.Context, p *model.PAT) basemodel.BaseApiError CreatePAT(ctx context.Context, p *model.PAT) basemodel.BaseApiError
GetPAT(ctx context.Context, pat string) (*model.PAT, basemodel.BaseApiError) GetPAT(ctx context.Context, pat string) (*model.PAT, basemodel.BaseApiError)
GetPATByID(ctx context.Context, id string) (*model.PAT, basemodel.BaseApiError)
GetUserByPAT(ctx context.Context, token string) (*basemodel.UserPayload, basemodel.BaseApiError)
ListPATs(ctx context.Context, userID string) ([]model.PAT, basemodel.BaseApiError) ListPATs(ctx context.Context, userID string) ([]model.PAT, basemodel.BaseApiError)
DeletePAT(ctx context.Context, id string) basemodel.BaseApiError DeletePAT(ctx context.Context, id string) basemodel.BaseApiError
} }

View File

@ -52,7 +52,7 @@ func InitDB(dataSourceName string) (*modelDao, error) {
CREATE TABLE IF NOT EXISTS personal_access_tokens ( CREATE TABLE IF NOT EXISTS personal_access_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
token TEXT NOT NULL, token TEXT NOT NULL UNIQUE,
name TEXT NOT NULL, name TEXT NOT NULL,
created_at INTEGER NOT NULL, created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL, expires_at INTEGER NOT NULL,

View File

@ -59,3 +59,48 @@ func (m *modelDao) GetPAT(ctx context.Context, token string) (*model.PAT, basemo
return &pats[0], nil return &pats[0], nil
} }
func (m *modelDao) GetPATByID(ctx context.Context, id string) (*model.PAT, basemodel.BaseApiError) {
pats := []model.PAT{}
if err := m.DB().Select(&pats, `SELECT * FROM personal_access_tokens WHERE id=?;`, id); err != nil {
return nil, model.InternalError(fmt.Errorf("failed to fetch PAT"))
}
if len(pats) != 1 {
return nil, &model.ApiError{
Typ: model.ErrorInternal,
Err: fmt.Errorf("found zero or multiple PATs with same token"),
}
}
return &pats[0], nil
}
func (m *modelDao) GetUserByPAT(ctx context.Context, token string) (*basemodel.UserPayload, basemodel.BaseApiError) {
users := []basemodel.UserPayload{}
query := `SELECT
u.id,
u.name,
u.email,
u.password,
u.created_at,
u.profile_picture_url,
u.org_id,
u.group_id
FROM users u, personal_access_tokens p
WHERE u.id = p.user_id and p.token=?;`
if err := m.DB().Select(&users, query, token); err != nil {
return nil, model.InternalError(fmt.Errorf("failed to fetch user from PAT, err: %v", err))
}
if len(users) != 1 {
return nil, &model.ApiError{
Typ: model.ErrorInternal,
Err: fmt.Errorf("found zero or multiple users with same PAT token"),
}
}
return &users[0], nil
}

View File

@ -6,5 +6,5 @@ type PAT struct {
Token string `json:"token" db:"token"` Token string `json:"token" db:"token"`
Name string `json:"name" db:"name"` Name string `json:"name" db:"name"`
CreatedAt int64 `json:"createdAt" db:"created_at"` CreatedAt int64 `json:"createdAt" db:"created_at"`
ExpiresAt int64 `json:"expiresAt" db:"expires_at"` ExpiresAt int64 `json:"expiresAt" db:"expires_at"` // unused as of now
} }