diff --git a/ee/query-service/app/api/auth.go b/ee/query-service/app/api/auth.go index 8d96320778..e013b87b29 100644 --- a/ee/query-service/app/api/auth.go +++ b/ee/query-service/app/api/auth.go @@ -113,7 +113,7 @@ func (ah *APIHandler) registerUser(w http.ResponseWriter, r *http.Request) { } if domain != nil && domain.SsoEnabled { - // so is enabled, create user and respond precheck data + // sso is enabled, create user and respond precheck data user, apierr := baseauth.RegisterInvitedUser(ctx, req, true) if apierr != nil { RespondError(w, apierr, nil) diff --git a/ee/query-service/dao/sqlite/auth.go b/ee/query-service/dao/sqlite/auth.go index 6eb46d2ea6..e06c073997 100644 --- a/ee/query-service/dao/sqlite/auth.go +++ b/ee/query-service/dao/sqlite/auth.go @@ -5,16 +5,61 @@ import ( "fmt" "net/url" "strings" + "time" + "github.com/google/uuid" "go.signoz.io/signoz/ee/query-service/constants" "go.signoz.io/signoz/ee/query-service/model" + baseauth "go.signoz.io/signoz/pkg/query-service/auth" baseconst "go.signoz.io/signoz/pkg/query-service/constants" basemodel "go.signoz.io/signoz/pkg/query-service/model" - baseauth "go.signoz.io/signoz/pkg/query-service/auth" + "go.signoz.io/signoz/pkg/query-service/utils" "go.uber.org/zap" ) -// PrepareSsoRedirect prepares redirect page link after SSO response +func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) (*basemodel.User, basemodel.BaseApiError) { + // get auth domain from email domain + domain, apierr := m.GetDomainByEmail(ctx, email) + + if apierr != nil { + zap.S().Errorf("failed to get domain from email", apierr) + return nil, model.InternalErrorStr("failed to get domain from email") + } + + hash, err := baseauth.PasswordHash(utils.GeneratePassowrd()) + if err != nil { + zap.S().Errorf("failed to generate password hash when registering a user via SSO redirect", zap.Error(err)) + return nil, model.InternalErrorStr("failed to generate password hash") + } + + group, apiErr := m.GetGroupByName(ctx, baseconst.ViewerGroup) + if apiErr != nil { + zap.S().Debugf("GetGroupByName failed, err: %v\n", apiErr.Err) + return nil, apiErr + } + + user := &basemodel.User{ + Id: uuid.NewString(), + Name: "", + Email: email, + Password: hash, + CreatedAt: time.Now().Unix(), + ProfilePictureURL: "", // Currently unused + GroupId: group.Id, + OrgId: domain.OrgId, + } + + user, apiErr = m.CreateUser(ctx, user, false) + if apiErr != nil { + zap.S().Debugf("CreateUser failed, err: %v\n", apiErr.Err) + return nil, apiErr + } + + return user, nil + +} + +// PrepareSsoRedirect prepares redirect page link after SSO response // is successfully parsed (i.e. valid email is available) func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email string) (redirectURL string, apierr basemodel.BaseApiError) { @@ -24,7 +69,20 @@ func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email st return "", model.BadRequestStr("invalid user email received from the auth provider") } - tokenStore, err := baseauth.GenerateJWTForUser(&userPayload.User) + user := &basemodel.User{} + + if userPayload == nil { + newUser, apiErr := m.createUserForSAMLRequest(ctx, email) + user = newUser + if apiErr != nil { + zap.S().Errorf("failed to create user with email received from auth provider: %v", apierr.Error()) + return "", apiErr + } + } else { + user = &userPayload.User + } + + tokenStore, err := baseauth.GenerateJWTForUser(user) if err != nil { zap.S().Errorf("failed to generate token for SSO login user", err) return "", model.InternalErrorStr("failed to generate token for the user") @@ -33,7 +91,7 @@ func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email st return fmt.Sprintf("%s?jwt=%s&usr=%s&refreshjwt=%s", redirectUri, tokenStore.AccessJwt, - userPayload.User.Id, + user.Id, tokenStore.RefreshJwt), nil } @@ -76,6 +134,7 @@ func (m *modelDao) PrecheckLogin(ctx context.Context, email, sourceUrl string) ( if userPayload == nil { resp.IsUser = false } + ssoAvailable := true err := m.checkFeature(model.SSO) if err != nil { @@ -91,6 +150,8 @@ func (m *modelDao) PrecheckLogin(ctx context.Context, email, sourceUrl string) ( if ssoAvailable { + resp.IsUser = true + // find domain from email orgDomain, apierr := m.GetDomainByEmail(ctx, email) if apierr != nil { diff --git a/ee/query-service/dao/sqlite/domain.go b/ee/query-service/dao/sqlite/domain.go index d1ef8aa8d2..9fbee9e9df 100644 --- a/ee/query-service/dao/sqlite/domain.go +++ b/ee/query-service/dao/sqlite/domain.go @@ -4,8 +4,8 @@ import ( "context" "database/sql" "encoding/json" - "net/url" "fmt" + "net/url" "strings" "time" @@ -28,29 +28,70 @@ type StoredDomain struct { // GetDomainFromSsoResponse uses relay state received from IdP to fetch // user domain. The domain is further used to process validity of the response. -// when sending login request to IdP we send relay state as URL (site url) -// with domainId as query parameter. +// when sending login request to IdP we send relay state as URL (site url) +// with domainId or domainName as query parameter. func (m *modelDao) GetDomainFromSsoResponse(ctx context.Context, relayState *url.URL) (*model.OrgDomain, error) { // derive domain id from relay state now - var domainIdStr string + var domainIdStr string + var domainNameStr string + var domain *model.OrgDomain + for k, v := range relayState.Query() { if k == "domainId" && len(v) > 0 { domainIdStr = strings.Replace(v[0], ":", "-", -1) } + if k == "domainName" && len(v) > 0 { + domainNameStr = v[0] + } } - domainId, err := uuid.Parse(domainIdStr) + if domainIdStr != "" { + domainId, err := uuid.Parse(domainIdStr) + if err != nil { + zap.S().Errorf("failed to parse domainId from relay state", err) + return nil, fmt.Errorf("failed to parse domainId from IdP response") + } + + domain, err = m.GetDomain(ctx, domainId) + if (err != nil) || domain == nil { + zap.S().Errorf("failed to find domain from domainId received in IdP response", err.Error()) + return nil, fmt.Errorf("invalid credentials") + } + } + + if domainNameStr != "" { + + domainFromDB, err := m.GetDomainByName(ctx, domainNameStr) + domain = domainFromDB + if (err != nil) || domain == nil { + zap.S().Errorf("failed to find domain from domainName received in IdP response", err.Error()) + return nil, fmt.Errorf("invalid credentials") + } + } + if domain != nil { + return domain, nil + } + + return nil, fmt.Errorf("failed to find domain received in IdP response") +} + +// GetDomainByName returns org domain for a given domain name +func (m *modelDao) GetDomainByName(ctx context.Context, name string) (*model.OrgDomain, basemodel.BaseApiError) { + + stored := StoredDomain{} + err := m.DB().Get(&stored, `SELECT * FROM org_domains WHERE name=$1 LIMIT 1`, name) + if err != nil { - zap.S().Errorf("failed to parse domain id from relay state", err) - return nil, fmt.Errorf("failed to parse response from IdP response") + if err == sql.ErrNoRows { + return nil, model.BadRequest(fmt.Errorf("invalid domain name")) + } + return nil, model.InternalError(err) } - domain, err := m.GetDomain(ctx, domainId) - if (err != nil) || domain == nil { - zap.S().Errorf("failed to find domain received in IdP response", err.Error()) - return nil, fmt.Errorf("invalid credentials") + domain := &model.OrgDomain{Id: stored.Id, Name: stored.Name, OrgId: stored.OrgId} + if err := domain.LoadConfig(stored.Data); err != nil { + return domain, model.InternalError(err) } - return domain, nil } diff --git a/ee/query-service/model/errors.go b/ee/query-service/model/errors.go index 6820cf8d44..7e7b8410e2 100644 --- a/ee/query-service/model/errors.go +++ b/ee/query-service/model/errors.go @@ -2,6 +2,7 @@ package model import ( "fmt" + basemodel "go.signoz.io/signoz/pkg/query-service/model" ) @@ -61,7 +62,6 @@ func InternalError(err error) *ApiError { } } - // InternalErrorStr returns a ApiError object of internal type for string input func InternalErrorStr(s string) *ApiError { return &ApiError{ @@ -69,6 +69,7 @@ func InternalErrorStr(s string) *ApiError { Err: fmt.Errorf(s), } } + var ( ErrorNone basemodel.ErrorType = "" ErrorTimeout basemodel.ErrorType = "timeout" diff --git a/pkg/query-service/auth/auth.go b/pkg/query-service/auth/auth.go index d2488a1399..7f78fa3660 100644 --- a/pkg/query-service/auth/auth.go +++ b/pkg/query-service/auth/auth.go @@ -165,7 +165,7 @@ func ResetPassword(ctx context.Context, req *model.ResetPasswordRequest) error { return errors.New("Invalid reset password request") } - hash, err := passwordHash(req.Password) + hash, err := PasswordHash(req.Password) if err != nil { return errors.Wrap(err, "Failed to generate password hash") } @@ -192,7 +192,7 @@ func ChangePassword(ctx context.Context, req *model.ChangePasswordRequest) error return ErrorInvalidCreds } - hash, err := passwordHash(req.NewPassword) + hash, err := PasswordHash(req.NewPassword) if err != nil { return errors.Wrap(err, "Failed to generate password hash") } @@ -243,7 +243,7 @@ func RegisterFirstUser(ctx context.Context, req *RegisterRequest) (*model.User, var hash string var err error - hash, err = passwordHash(req.Password) + hash, err = PasswordHash(req.Password) if err != nil { zap.S().Errorf("failed to generate password hash when registering a user", zap.Error(err)) return nil, model.InternalError(model.ErrSignupFailed{}) @@ -314,13 +314,13 @@ func RegisterInvitedUser(ctx context.Context, req *RegisterRequest, nopassword b // check if password is not empty, as for SSO case it can be if req.Password != "" { - hash, err = passwordHash(req.Password) + hash, err = PasswordHash(req.Password) if err != nil { zap.S().Errorf("failed to generate password hash when registering a user", zap.Error(err)) return nil, model.InternalError(model.ErrSignupFailed{}) } } else { - hash, err = passwordHash(utils.GeneratePassowrd()) + hash, err = PasswordHash(utils.GeneratePassowrd()) if err != nil { zap.S().Errorf("failed to generate password hash when registering a user", zap.Error(err)) return nil, model.InternalError(model.ErrSignupFailed{}) @@ -419,7 +419,7 @@ func authenticateLogin(ctx context.Context, req *model.LoginRequest) (*model.Use } // Generate hash from the password. -func passwordHash(pass string) (string, error) { +func PasswordHash(pass string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost) if err != nil { return "", err