diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index 464517ef1a..348cdbddd2 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -1,6 +1,7 @@ package app import ( + "bufio" "bytes" "context" "encoding/json" @@ -317,7 +318,7 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server, // ip here for alert manager AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "DELETE", "POST", "PUT", "PATCH"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "SIGNOZ-API-KEY"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "SIGNOZ-API-KEY", "X-SIGNOZ-QUERY-ID", "Sec-WebSocket-Protocol"}, }) handler := c.Handler(r) @@ -362,7 +363,7 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler) (*http.Server, e c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "DELETE", "POST", "PUT", "PATCH", "OPTIONS"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "cache-control"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "cache-control", "X-SIGNOZ-QUERY-ID", "Sec-WebSocket-Protocol"}, }) handler := c.Handler(r) @@ -418,6 +419,15 @@ func (lrw *loggingResponseWriter) Flush() { lrw.ResponseWriter.(http.Flusher).Flush() } +// Support websockets +func (lrw *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := lrw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("hijack not supported") + } + return h.Hijack() +} + func extractQueryRangeData(path string, r *http.Request) (map[string]interface{}, bool) { pathToExtractBodyFromV3 := "/api/v3/query_range" pathToExtractBodyFromV4 := "/api/v4/query_range" diff --git a/pkg/query-service/app/clickhouseReader/query_progress/inmemory_tracker.go b/pkg/query-service/app/clickhouseReader/query_progress/inmemory_tracker.go new file mode 100644 index 0000000000..d29a61cb5c --- /dev/null +++ b/pkg/query-service/app/clickhouseReader/query_progress/inmemory_tracker.go @@ -0,0 +1,255 @@ +package queryprogress + +import ( + "fmt" + "sync" + + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/google/uuid" + "go.signoz.io/signoz/pkg/query-service/model" + v3 "go.signoz.io/signoz/pkg/query-service/model/v3" + "go.uber.org/zap" + "golang.org/x/exp/maps" +) + +// tracks progress and manages subscriptions for all queries +type inMemoryQueryProgressTracker struct { + queries map[string]*queryTracker + lock sync.RWMutex +} + +func (tracker *inMemoryQueryProgressTracker) ReportQueryStarted( + queryId string, +) (postQueryCleanup func(), err *model.ApiError) { + tracker.lock.Lock() + defer tracker.lock.Unlock() + + _, exists := tracker.queries[queryId] + if exists { + return nil, model.BadRequest(fmt.Errorf( + "query %s already started", queryId, + )) + } + + tracker.queries[queryId] = newQueryTracker(queryId) + + return func() { + tracker.onQueryFinished(queryId) + }, nil +} + +func (tracker *inMemoryQueryProgressTracker) ReportQueryProgress( + queryId string, chProgress *clickhouse.Progress, +) *model.ApiError { + queryTracker, err := tracker.getQueryTracker(queryId) + if err != nil { + return err + } + + queryTracker.handleProgressUpdate(chProgress) + return nil +} + +func (tracker *inMemoryQueryProgressTracker) SubscribeToQueryProgress( + queryId string, +) (<-chan v3.QueryProgress, func(), *model.ApiError) { + queryTracker, err := tracker.getQueryTracker(queryId) + if err != nil { + return nil, nil, err + } + + return queryTracker.subscribe() +} + +func (tracker *inMemoryQueryProgressTracker) onQueryFinished( + queryId string, +) { + tracker.lock.Lock() + queryTracker := tracker.queries[queryId] + if queryTracker != nil { + delete(tracker.queries, queryId) + } + tracker.lock.Unlock() + + if queryTracker != nil { + queryTracker.onFinished() + } +} + +func (tracker *inMemoryQueryProgressTracker) getQueryTracker( + queryId string, +) (*queryTracker, *model.ApiError) { + tracker.lock.RLock() + defer tracker.lock.RUnlock() + + queryTracker := tracker.queries[queryId] + if queryTracker == nil { + return nil, model.NotFoundError(fmt.Errorf( + "query %s doesn't exist", queryId, + )) + } + + return queryTracker, nil +} + +// Tracks progress and manages subscriptions for a single query +type queryTracker struct { + queryId string + isFinished bool + + progress *v3.QueryProgress + subscriptions map[string]*queryProgressSubscription + + lock sync.Mutex +} + +func newQueryTracker(queryId string) *queryTracker { + return &queryTracker{ + queryId: queryId, + subscriptions: map[string]*queryProgressSubscription{}, + } +} + +func (qt *queryTracker) handleProgressUpdate(p *clickhouse.Progress) { + qt.lock.Lock() + defer qt.lock.Unlock() + + if qt.isFinished { + zap.L().Warn( + "received clickhouse progress update for finished query", + zap.String("queryId", qt.queryId), zap.Any("progress", p), + ) + return + } + + if qt.progress == nil { + // This is the first update + qt.progress = &v3.QueryProgress{} + } + updateQueryProgress(qt.progress, p) + + // broadcast latest state to all subscribers. + for _, sub := range maps.Values(qt.subscriptions) { + sub.send(*qt.progress) + } +} + +func (qt *queryTracker) subscribe() ( + <-chan v3.QueryProgress, func(), *model.ApiError, +) { + qt.lock.Lock() + defer qt.lock.Unlock() + + if qt.isFinished { + return nil, nil, model.NotFoundError(fmt.Errorf( + "query %s already finished", qt.queryId, + )) + } + + subscriberId := uuid.NewString() + subscription := newQueryProgressSubscription() + qt.subscriptions[subscriberId] = subscription + + if qt.progress != nil { + subscription.send(*qt.progress) + } + + return subscription.ch, func() { + qt.unsubscribe(subscriberId) + }, nil +} + +func (qt *queryTracker) unsubscribe(subscriberId string) { + qt.lock.Lock() + defer qt.lock.Unlock() + + if qt.isFinished { + zap.L().Debug( + "received unsubscribe request after query finished", + zap.String("subscriber", subscriberId), + zap.String("queryId", qt.queryId), + ) + return + } + + subscription := qt.subscriptions[subscriberId] + if subscription != nil { + subscription.close() + delete(qt.subscriptions, subscriberId) + } +} + +func (qt *queryTracker) onFinished() { + qt.lock.Lock() + defer qt.lock.Unlock() + + if qt.isFinished { + zap.L().Warn( + "receiver query finish report after query finished", + zap.String("queryId", qt.queryId), + ) + return + } + + for subId, sub := range qt.subscriptions { + sub.close() + delete(qt.subscriptions, subId) + } + + qt.isFinished = true +} + +type queryProgressSubscription struct { + ch chan v3.QueryProgress + isClosed bool + lock sync.Mutex +} + +func newQueryProgressSubscription() *queryProgressSubscription { + ch := make(chan v3.QueryProgress, 1000) + return &queryProgressSubscription{ + ch: ch, + } +} + +// Must not block or panic in any scenario +func (ch *queryProgressSubscription) send(progress v3.QueryProgress) { + ch.lock.Lock() + defer ch.lock.Unlock() + + if ch.isClosed { + zap.L().Error( + "can't send query progress: channel already closed.", + zap.Any("progress", progress), + ) + return + } + + // subscription channels are expected to have big enough buffers to ensure + // blocking while sending doesn't happen in the happy path + select { + case ch.ch <- progress: + zap.L().Debug("published query progress", zap.Any("progress", progress)) + default: + zap.L().Error( + "couldn't publish query progress. dropping update.", + zap.Any("progress", progress), + ) + } +} + +func (ch *queryProgressSubscription) close() { + ch.lock.Lock() + defer ch.lock.Unlock() + + if !ch.isClosed { + close(ch.ch) + ch.isClosed = true + } +} + +func updateQueryProgress(qp *v3.QueryProgress, chProgress *clickhouse.Progress) { + qp.ReadRows += chProgress.Rows + qp.ReadBytes += chProgress.Bytes + qp.ElapsedMs += uint64(chProgress.Elapsed.Milliseconds()) +} diff --git a/pkg/query-service/app/clickhouseReader/query_progress/tracker.go b/pkg/query-service/app/clickhouseReader/query_progress/tracker.go new file mode 100644 index 0000000000..d424c99c57 --- /dev/null +++ b/pkg/query-service/app/clickhouseReader/query_progress/tracker.go @@ -0,0 +1,31 @@ +package queryprogress + +import ( + "github.com/ClickHouse/clickhouse-go/v2" + "go.signoz.io/signoz/pkg/query-service/model" + v3 "go.signoz.io/signoz/pkg/query-service/model/v3" +) + +type QueryProgressTracker interface { + // Tells the tracker that query with id `queryId` has started. + // Progress can only be reported for and tracked for a query that is in progress. + // Returns a cleanup function that must be called after the query finishes. + ReportQueryStarted(queryId string) (postQueryCleanup func(), err *model.ApiError) + + // Report progress stats received from clickhouse for `queryId` + ReportQueryProgress(queryId string, chProgress *clickhouse.Progress) *model.ApiError + + // Subscribe to progress updates for `queryId` + // The returned channel will produce `QueryProgress` instances representing + // the latest state of query progress stats. Also returns a function that + // can be called to unsubscribe before the query finishes, if needed. + SubscribeToQueryProgress(queryId string) (ch <-chan v3.QueryProgress, unsubscribe func(), err *model.ApiError) +} + +func NewQueryProgressTracker() QueryProgressTracker { + // InMemory tracker is useful only for single replica query service setups. + // Multi replica setups must use a centralized store for tracking and subscribing to query progress + return &inMemoryQueryProgressTracker{ + queries: map[string]*queryTracker{}, + } +} diff --git a/pkg/query-service/app/clickhouseReader/query_progress/tracker_test.go b/pkg/query-service/app/clickhouseReader/query_progress/tracker_test.go new file mode 100644 index 0000000000..4babe47f82 --- /dev/null +++ b/pkg/query-service/app/clickhouseReader/query_progress/tracker_test.go @@ -0,0 +1,102 @@ +package queryprogress + +import ( + "testing" + "time" + + "github.com/ClickHouse/clickhouse-go/v2" + "github.com/stretchr/testify/require" + "go.signoz.io/signoz/pkg/query-service/model" + v3 "go.signoz.io/signoz/pkg/query-service/model/v3" +) + +func TestQueryProgressTracking(t *testing.T) { + require := require.New(t) + + tracker := NewQueryProgressTracker() + + testQueryId := "test-query" + + testProgress := &clickhouse.Progress{} + err := tracker.ReportQueryProgress(testQueryId, testProgress) + require.NotNil(err, "shouldn't be able to report query progress before query has been started") + require.Equal(err.Type(), model.ErrorNotFound) + + ch, unsubscribe, err := tracker.SubscribeToQueryProgress(testQueryId) + require.NotNil(err, "shouldn't be able to subscribe for progress updates before query has been started") + require.Equal(err.Type(), model.ErrorNotFound) + require.Nil(ch) + require.Nil(unsubscribe) + + reportQueryFinished, err := tracker.ReportQueryStarted(testQueryId) + require.Nil(err, "should be able to report start of a query to be tracked") + + testProgress1 := &clickhouse.Progress{ + Rows: 10, + Bytes: 20, + TotalRows: 100, + Elapsed: 20 * time.Millisecond, + } + err = tracker.ReportQueryProgress(testQueryId, testProgress1) + require.Nil(err, "should be able to report progress after query has started") + + ch, unsubscribe, err = tracker.SubscribeToQueryProgress(testQueryId) + require.Nil(err, "should be able to subscribe to query progress updates after query started") + require.NotNil(ch) + require.NotNil(unsubscribe) + + expectedProgress := v3.QueryProgress{} + updateQueryProgress(&expectedProgress, testProgress1) + require.Equal(expectedProgress.ReadRows, testProgress1.Rows) + select { + case qp := <-ch: + require.Equal(qp, expectedProgress) + default: + require.Fail("should receive latest query progress state immediately after subscription") + } + select { + case _ = <-ch: + require.Fail("should have had only one pending update at this point") + default: + } + + testProgress2 := &clickhouse.Progress{ + Rows: 20, + Bytes: 40, + TotalRows: 100, + Elapsed: 40 * time.Millisecond, + } + err = tracker.ReportQueryProgress(testQueryId, testProgress2) + require.Nil(err, "should be able to report progress multiple times while query is in progress") + + updateQueryProgress(&expectedProgress, testProgress2) + select { + case qp := <-ch: + require.Equal(qp, expectedProgress) + default: + require.Fail("should receive updates whenever new progress updates get reported to tracker") + } + select { + case _ = <-ch: + require.Fail("should have had only one pending update at this point") + default: + } + + reportQueryFinished() + select { + case _, isSubscriptionChannelOpen := <-ch: + require.False(isSubscriptionChannelOpen, "subscription channels should get closed after query finishes") + default: + require.Fail("subscription channels should get closed after query finishes") + } + + err = tracker.ReportQueryProgress(testQueryId, testProgress) + require.NotNil(err, "shouldn't be able to report query progress after query has finished") + require.Equal(err.Type(), model.ErrorNotFound) + + ch, unsubscribe, err = tracker.SubscribeToQueryProgress(testQueryId) + require.NotNil(err, "shouldn't be able to subscribe for progress updates after query has finished") + require.Equal(err.Type(), model.ErrorNotFound) + require.Nil(ch) + require.Nil(unsubscribe) +} diff --git a/pkg/query-service/app/clickhouseReader/reader.go b/pkg/query-service/app/clickhouseReader/reader.go index cc27b4f2eb..b34780c37d 100644 --- a/pkg/query-service/app/clickhouseReader/reader.go +++ b/pkg/query-service/app/clickhouseReader/reader.go @@ -41,6 +41,7 @@ import ( promModel "github.com/prometheus/common/model" "go.uber.org/zap" + queryprogress "go.signoz.io/signoz/pkg/query-service/app/clickhouseReader/query_progress" "go.signoz.io/signoz/pkg/query-service/app/dashboards" "go.signoz.io/signoz/pkg/query-service/app/explorer" "go.signoz.io/signoz/pkg/query-service/app/logs" @@ -122,6 +123,7 @@ type ClickHouseReader struct { queryEngine *promql.Engine remoteStorage *remote.Storage fanoutStorage *storage.Storage + queryProgressTracker queryprogress.QueryProgressTracker promConfigFile string promConfig *config.Config @@ -215,6 +217,7 @@ func NewReaderFromClickhouseConnection( promConfigFile: configFile, featureFlags: featureFlag, cluster: cluster, + queryProgressTracker: queryprogress.NewQueryProgressTracker(), } } @@ -4706,6 +4709,30 @@ func (r *ClickHouseReader) GetTimeSeriesResultV3(ctx context.Context, query stri defer utils.Elapsed("GetTimeSeriesResultV3", ctxArgs)() + // Hook up query progress reporting if requested. + queryId := ctx.Value("queryId") + if queryId != nil { + qid, ok := queryId.(string) + if !ok { + zap.L().Error("GetTimeSeriesResultV3: queryId in ctx not a string as expected", zap.Any("queryId", queryId)) + + } else { + ctx = clickhouse.Context(ctx, clickhouse.WithProgress( + func(p *clickhouse.Progress) { + go func() { + err := r.queryProgressTracker.ReportQueryProgress(qid, p) + if err != nil { + zap.L().Error( + "Couldn't report query progress", + zap.String("queryId", qid), zap.Error(err), + ) + } + }() + }, + )) + } + } + rows, err := r.db.Query(ctx, query) if err != nil { @@ -5464,3 +5491,15 @@ func (r *ClickHouseReader) GetMinAndMaxTimestampForTraceID(ctx context.Context, return minTime.UnixNano(), maxTime.UnixNano(), nil } + +func (r *ClickHouseReader) ReportQueryStartForProgressTracking( + queryId string, +) (func(), *model.ApiError) { + return r.queryProgressTracker.ReportQueryStarted(queryId) +} + +func (r *ClickHouseReader) SubscribeToQueryProgress( + queryId string, +) (<-chan v3.QueryProgress, func(), *model.ApiError) { + return r.queryProgressTracker.SubscribeToQueryProgress(queryId) +} diff --git a/pkg/query-service/app/http_handler.go b/pkg/query-service/app/http_handler.go index 07a1814ac7..97b5f0de2a 100644 --- a/pkg/query-service/app/http_handler.go +++ b/pkg/query-service/app/http_handler.go @@ -9,6 +9,7 @@ import ( "io" "math" "net/http" + "net/url" "regexp" "slices" "strconv" @@ -18,6 +19,7 @@ import ( "time" "github.com/gorilla/mux" + "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" _ "github.com/mattn/go-sqlite3" "github.com/prometheus/prometheus/promql" @@ -101,6 +103,9 @@ type APIHandler struct { // at the moment, we mark the app ready when the first user // is registers. SetupCompleted bool + + // Websocket connection upgrader + Upgrader *websocket.Upgrader } type APIHandlerOpts struct { @@ -207,6 +212,29 @@ func NewAPIHandler(opts APIHandlerOpts) (*APIHandler, error) { // to signup signoz through invite link only. aH.SetupCompleted = true } + + aH.Upgrader = &websocket.Upgrader{ + // Same-origin check is the server's responsibility in websocket spec. + CheckOrigin: func(r *http.Request) bool { + // Based on the default CheckOrigin implementation in websocket package. + originHeader := r.Header.Get("Origin") + if len(originHeader) < 1 { + return false + } + origin, err := url.Parse(originHeader) + if err != nil { + return false + } + + // Allow cross origin websocket connections on localhost + if strings.HasPrefix(origin.Host, "localhost") { + return true + } + + return origin.Host == r.Host + }, + } + return aH, nil } @@ -305,6 +333,9 @@ func (aH *APIHandler) RegisterQueryRangeV3Routes(router *mux.Router, am *AuthMid subRouter.HandleFunc("/filter_suggestions", am.ViewAccess(aH.getQueryBuilderSuggestions)).Methods(http.MethodGet) + // websocket handler for query progress + subRouter.HandleFunc("/query_progress", am.ViewAccess(aH.GetQueryProgressUpdates)).Methods(http.MethodGet) + // live logs subRouter.HandleFunc("/logs/livetail", am.ViewAccess(aH.liveTailLogs)).Methods(http.MethodGet) } @@ -3517,6 +3548,24 @@ func (aH *APIHandler) queryRangeV3(ctx context.Context, queryRangeParams *v3.Que } } + // Hook up query progress tracking if requested + queryIdHeader := r.Header.Get("X-SIGNOZ-QUERY-ID") + if len(queryIdHeader) > 0 { + ctx = context.WithValue(ctx, "queryId", queryIdHeader) + + onQueryFinished, err := aH.reader.ReportQueryStartForProgressTracking(queryIdHeader) + if err != nil { + zap.L().Error( + "couldn't report query start for progress tracking", + zap.String("queryId", queryIdHeader), zap.Error(err), + ) + } else { + defer func() { + go onQueryFinished() + }() + } + } + result, errQuriesByName, err = aH.querier.QueryRange(ctx, queryRangeParams, spanKeys) if err != nil { @@ -3666,6 +3715,77 @@ func (aH *APIHandler) QueryRangeV3(w http.ResponseWriter, r *http.Request) { aH.queryRangeV3(r.Context(), queryRangeParams, w, r) } +func (aH *APIHandler) GetQueryProgressUpdates(w http.ResponseWriter, r *http.Request) { + // Upgrade connection to websocket, sending back the requested protocol + // value for sec-websocket-protocol + // + // Since js websocket API doesn't allow setting headers, this header is often + // used for passing auth tokens. As per websocket spec the connection will only + // succeed if the requested `Sec-Websocket-Protocol` is sent back as a header + // in the upgrade response (signifying that the protocol is supported by the server). + upgradeResponseHeaders := http.Header{} + requestedProtocol := r.Header.Get("Sec-WebSocket-Protocol") + if len(requestedProtocol) > 0 { + upgradeResponseHeaders.Add("Sec-WebSocket-Protocol", requestedProtocol) + } + + c, err := aH.Upgrader.Upgrade(w, r, upgradeResponseHeaders) + if err != nil { + RespondError(w, model.InternalError(fmt.Errorf( + "couldn't upgrade connection: %w", err, + )), nil) + return + } + defer c.Close() + + // Websocket upgrade complete. Subscribe to query progress and send updates to client + // + // Note: we handle any subscription problems (queryId query param missing or query already complete etc) + // after the websocket connection upgrade by closing the channel. + // The other option would be to handle the errors before websocket upgrade by sending an + // error response instead of the upgrade response, but that leads to a generic websocket + // connection failure on the client. + + queryId := r.URL.Query().Get("q") + + progressCh, unsubscribe, apiErr := aH.reader.SubscribeToQueryProgress(queryId) + if apiErr != nil { + // Shouldn't happen unless query progress requested after query finished + zap.L().Warn( + "couldn't subscribe to query progress", + zap.String("queryId", queryId), zap.Any("error", err), + ) + return + } + defer func() { go unsubscribe() }() + + for queryProgress := range progressCh { + msg, err := json.Marshal(queryProgress) + if err != nil { + zap.L().Error( + "failed to serialize progress message", + zap.String("queryId", queryId), zap.Any("progress", queryProgress), zap.Error(err), + ) + continue + } + + err = c.WriteMessage(websocket.TextMessage, msg) + if err != nil { + zap.L().Error( + "failed to write progress msg to websocket", + zap.String("queryId", queryId), zap.String("msg", string(msg)), zap.Error(err), + ) + break + + } else { + zap.L().Debug( + "wrote progress msg to websocket", + zap.String("queryId", queryId), zap.String("msg", string(msg)), zap.Error(err), + ) + } + } +} + func (aH *APIHandler) liveTailLogs(w http.ResponseWriter, r *http.Request) { // get the param from url and add it to body diff --git a/pkg/query-service/app/server.go b/pkg/query-service/app/server.go index 254b6f6f78..3f5352edbd 100644 --- a/pkg/query-service/app/server.go +++ b/pkg/query-service/app/server.go @@ -1,6 +1,7 @@ package app import ( + "bufio" "bytes" "context" "encoding/json" @@ -262,7 +263,7 @@ func (s *Server) createPrivateServer(api *APIHandler) (*http.Server, error) { // ip here for alert manager AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "DELETE", "POST", "PUT", "PATCH"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-SIGNOZ-QUERY-ID", "Sec-WebSocket-Protocol"}, }) handler := c.Handler(r) @@ -308,7 +309,7 @@ func (s *Server) createPublicServer(api *APIHandler) (*http.Server, error) { c := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "DELETE", "POST", "PUT", "PATCH", "OPTIONS"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "cache-control"}, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "cache-control", "X-SIGNOZ-QUERY-ID", "Sec-WebSocket-Protocol"}, }) handler := c.Handler(r) @@ -423,6 +424,15 @@ func (lrw *loggingResponseWriter) Flush() { lrw.ResponseWriter.(http.Flusher).Flush() } +// Support websockets +func (lrw *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := lrw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("hijack not supported") + } + return h.Hijack() +} + func extractQueryRangeV3Data(path string, r *http.Request) (map[string]interface{}, bool) { pathToExtractBodyFrom := "/api/v3/query_range" diff --git a/pkg/query-service/auth/jwt.go b/pkg/query-service/auth/jwt.go index f57bb2ae18..637119fe18 100644 --- a/pkg/query-service/auth/jwt.go +++ b/pkg/query-service/auth/jwt.go @@ -20,7 +20,7 @@ var ( ) func ParseJWT(jwtStr string) (jwt.MapClaims, error) { - // TODO[@vikrantgupta25] : to update this to the claims check function for better integrity of JWT + // TODO[@vikrantgupta25] : to update this to the claims check function for better integrity of JWT // reference - https://pkg.go.dev/github.com/golang-jwt/jwt/v5#Parser.ParseWithClaims token, err := jwt.Parse(jwtStr, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { @@ -53,7 +53,7 @@ func validateUser(tok string) (*model.UserPayload, error) { var orgId string if claims["orgId"] != nil { - orgId = claims["orgId"].(string) + orgId = claims["orgId"].(string) } return &model.UserPayload{ @@ -83,7 +83,22 @@ func ExtractJwtFromContext(ctx context.Context) (string, bool) { } func ExtractJwtFromRequest(r *http.Request) (string, error) { - return jwtmiddleware.FromAuthHeader(r) + authHeaderJwt, err := jwtmiddleware.FromAuthHeader(r) + if err != nil { + return "", err + } + + if len(authHeaderJwt) > 0 { + return authHeaderJwt, nil + } + + // We expect websocket connections to send auth JWT in the + // `Sec-Websocket-Protocol` header. + // + // The standard js websocket API doesn't allow setting headers + // other than the `Sec-WebSocket-Protocol` header, which is often + // used for auth purposes as a result. + return r.Header.Get("Sec-WebSocket-Protocol"), nil } func ExtractUserIdFromContext(ctx context.Context) (string, error) { diff --git a/pkg/query-service/interfaces/interface.go b/pkg/query-service/interfaces/interface.go index 819f452f3b..4086947d4f 100644 --- a/pkg/query-service/interfaces/interface.go +++ b/pkg/query-service/interfaces/interface.go @@ -117,6 +117,10 @@ type Reader interface { GetAvgResolutionTimeByInterval(ctx context.Context, ruleID string, params *v3.QueryRuleStateHistory) (*v3.Series, error) ReadRuleStateHistoryTopContributorsByRuleID(ctx context.Context, ruleID string, params *v3.QueryRuleStateHistory) ([]v3.RuleStateHistoryContributor, error) GetMinAndMaxTimestampForTraceID(ctx context.Context, traceID []string) (int64, int64, error) + + // Query Progress tracking helpers. + ReportQueryStartForProgressTracking(queryId string) (reportQueryFinished func(), err *model.ApiError) + SubscribeToQueryProgress(queryId string) (<-chan v3.QueryProgress, func(), *model.ApiError) } type Querier interface { diff --git a/pkg/query-service/model/v3/v3.go b/pkg/query-service/model/v3/v3.go index 3b5e042c6d..6f7881a336 100644 --- a/pkg/query-service/model/v3/v3.go +++ b/pkg/query-service/model/v3/v3.go @@ -1247,3 +1247,11 @@ type Stats struct { CurrentAvgResolutionTimeSeries *Series `json:"currentAvgResolutionTimeSeries"` PastAvgResolutionTimeSeries *Series `json:"pastAvgResolutionTimeSeries"` } + +type QueryProgress struct { + ReadRows uint64 `json:"read_rows"` + + ReadBytes uint64 `json:"read_bytes"` + + ElapsedMs uint64 `json:"elapsed_ms"` +}