diff --git a/.gitignore b/.gitignore index 8fe54dcf3d..2bd9238255 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,6 @@ e2e/.auth # go vendor/ **/main/** + +# git-town +.git-branches.toml diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index 082ddcd358..d77e71cc28 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -376,6 +376,7 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler) (*http.Server, e }, nil } +// TODO(remove): Implemented at pkg/http/middleware/logging.go // loggingMiddleware is used for logging public api calls func loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -387,6 +388,7 @@ func loggingMiddleware(next http.Handler) http.Handler { }) } +// TODO(remove): Implemented at pkg/http/middleware/logging.go // loggingMiddlewarePrivate is used for logging private api calls // from internal services like alert manager func loggingMiddlewarePrivate(next http.Handler) http.Handler { @@ -399,27 +401,32 @@ func loggingMiddlewarePrivate(next http.Handler) http.Handler { }) } +// TODO(remove): Implemented at pkg/http/middleware/logging.go type loggingResponseWriter struct { http.ResponseWriter statusCode int } +// TODO(remove): Implemented at pkg/http/middleware/logging.go func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { // WriteHeader(int) is not called if our response implicitly returns 200 OK, so // we default to that status code. return &loggingResponseWriter{w, http.StatusOK} } +// TODO(remove): Implemented at pkg/http/middleware/logging.go func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } +// TODO(remove): Implemented at pkg/http/middleware/logging.go // Flush implements the http.Flush interface. func (lrw *loggingResponseWriter) Flush() { lrw.ResponseWriter.(http.Flusher).Flush() } +// TODO(remove): Implemented at pkg/http/middleware/logging.go // Support websockets func (lrw *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { h, ok := lrw.ResponseWriter.(http.Hijacker) @@ -565,6 +572,7 @@ func (s *Server) analyticsMiddleware(next http.Handler) http.Handler { }) } +// TODO(remove): Implemented at pkg/http/middleware/timeout.go func setTimeoutMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/pkg/http/doc.go b/pkg/http/doc.go new file mode 100644 index 0000000000..5e62c61b15 --- /dev/null +++ b/pkg/http/doc.go @@ -0,0 +1,3 @@ +// package http contains all http related functions such +// as servers, middlewares, routers and renders. +package http diff --git a/pkg/http/middleware/doc.go b/pkg/http/middleware/doc.go new file mode 100644 index 0000000000..911746777c --- /dev/null +++ b/pkg/http/middleware/doc.go @@ -0,0 +1,2 @@ +// package middleware contains an implementation of all middlewares. +package middleware diff --git a/pkg/http/middleware/logging.go b/pkg/http/middleware/logging.go new file mode 100644 index 0000000000..ef755f6648 --- /dev/null +++ b/pkg/http/middleware/logging.go @@ -0,0 +1,72 @@ +package middleware + +import ( + "bytes" + "net" + "net/http" + "time" + + "github.com/gorilla/mux" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "go.uber.org/zap" +) + +const ( + logMessage string = "::RECEIVED-REQUEST::" +) + +type Logging struct { + logger *zap.Logger +} + +func NewLogging(logger *zap.Logger) *Logging { + if logger == nil { + panic("cannot build logging, logger is empty") + } + + return &Logging{ + logger: logger.Named(pkgname), + } +} + +func (middleware *Logging) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + ctx := req.Context() + start := time.Now() + host, port, _ := net.SplitHostPort(req.Host) + path, err := mux.CurrentRoute(req).GetPathTemplate() + if err != nil { + path = req.URL.Path + } + + fields := []zap.Field{ + zap.Any("context", ctx), + zap.String(string(semconv.ClientAddressKey), req.RemoteAddr), + zap.String(string(semconv.UserAgentOriginalKey), req.UserAgent()), + zap.String(string(semconv.ServerAddressKey), host), + zap.String(string(semconv.ServerPortKey), port), + zap.Int64(string(semconv.HTTPRequestSizeKey), req.ContentLength), + zap.String(string(semconv.HTTPRouteKey), path), + } + + buf := new(bytes.Buffer) + writer := newBadResponseLoggingWriter(rw, buf) + next.ServeHTTP(writer, req) + + statusCode, err := writer.StatusCode(), writer.WriteError() + fields = append(fields, + zap.Int(string(semconv.HTTPResponseStatusCodeKey), statusCode), + zap.Duration(string(semconv.HTTPServerRequestDurationName), time.Since(start)), + ) + if err != nil { + fields = append(fields, zap.Error(err)) + middleware.logger.Error(logMessage, fields...) + } else { + if buf.Len() != 0 { + fields = append(fields, zap.String("response.body", buf.String())) + } + + middleware.logger.Info(logMessage, fields...) + } + }) +} diff --git a/pkg/http/middleware/middleware.go b/pkg/http/middleware/middleware.go new file mode 100644 index 0000000000..6313089aa4 --- /dev/null +++ b/pkg/http/middleware/middleware.go @@ -0,0 +1,20 @@ +package middleware + +import "net/http" + +const ( + pkgname string = "go.signoz.io/pkg/http/middleware" +) + +// Wrapper is an interface implemented by all middlewares +type Wrapper interface { + Wrap(http.Handler) http.Handler +} + +// WrapperFunc is to Wrapper as http.HandlerFunc is to http.Handler +type WrapperFunc func(http.Handler) http.Handler + +// WrapperFunc implements Wrapper +func (m WrapperFunc) Wrap(next http.Handler) http.Handler { + return m(next) +} diff --git a/pkg/http/middleware/response.go b/pkg/http/middleware/response.go new file mode 100644 index 0000000000..deb0f3dd81 --- /dev/null +++ b/pkg/http/middleware/response.go @@ -0,0 +1,122 @@ +package middleware + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" +) + +const ( + maxResponseBodyInLogs = 4096 // At most 4k bytes from response bodies in our logs. +) + +type badResponseLoggingWriter interface { + http.ResponseWriter + // Get the status code. + StatusCode() int + // Get the error while writing. + WriteError() error +} + +func newBadResponseLoggingWriter(rw http.ResponseWriter, buffer io.Writer) badResponseLoggingWriter { + b := nonFlushingBadResponseLoggingWriter{ + rw: rw, + buffer: buffer, + logBody: false, + bodyBytesLeft: maxResponseBodyInLogs, + statusCode: http.StatusOK, + } + + if f, ok := rw.(http.Flusher); ok { + return &flushingBadResponseLoggingWriter{b, f} + } + + return &b +} + +type nonFlushingBadResponseLoggingWriter struct { + rw http.ResponseWriter + buffer io.Writer + logBody bool + bodyBytesLeft int + statusCode int + writeError error // The error returned when downstream Write() fails. +} + +// Extends nonFlushingBadResponseLoggingWriter that implements http.Flusher +type flushingBadResponseLoggingWriter struct { + nonFlushingBadResponseLoggingWriter + f http.Flusher +} + +// Unwrap method is used by http.ResponseController to get access to original http.ResponseWriter. +func (writer *nonFlushingBadResponseLoggingWriter) Unwrap() http.ResponseWriter { + return writer.rw +} + +// Header returns the header map that will be sent by WriteHeader. +// Implements ResponseWriter. +func (writer *nonFlushingBadResponseLoggingWriter) Header() http.Header { + return writer.rw.Header() +} + +// WriteHeader writes the HTTP response header. +func (writer *nonFlushingBadResponseLoggingWriter) WriteHeader(statusCode int) { + writer.statusCode = statusCode + if statusCode >= 500 || statusCode == 400 { + writer.logBody = true + } + writer.rw.WriteHeader(statusCode) +} + +// Writes HTTP response data. +func (writer *nonFlushingBadResponseLoggingWriter) Write(data []byte) (int, error) { + if writer.statusCode == 0 { + // WriteHeader has (probably) not been called, so we need to call it with StatusOK to fulfill the interface contract. + // https://godoc.org/net/http#ResponseWriter + writer.WriteHeader(http.StatusOK) + } + n, err := writer.rw.Write(data) + if writer.logBody { + writer.captureResponseBody(data) + } + if err != nil { + writer.writeError = err + } + return n, err +} + +// Hijack hijacks the first response writer that is a Hijacker. +func (writer *nonFlushingBadResponseLoggingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := writer.rw.(http.Hijacker) + if ok { + return hj.Hijack() + } + return nil, nil, fmt.Errorf("cannot cast underlying response writer to Hijacker") +} + +func (writer *nonFlushingBadResponseLoggingWriter) StatusCode() int { + return writer.statusCode +} + +func (writer *nonFlushingBadResponseLoggingWriter) WriteError() error { + return writer.writeError +} + +func (writer *flushingBadResponseLoggingWriter) Flush() { + writer.f.Flush() +} + +func (writer *nonFlushingBadResponseLoggingWriter) captureResponseBody(data []byte) { + if len(data) > writer.bodyBytesLeft { + _, _ = writer.buffer.Write(data[:writer.bodyBytesLeft]) + _, _ = io.WriteString(writer.buffer, "...") + writer.bodyBytesLeft = 0 + writer.logBody = false + } else { + _, _ = writer.buffer.Write(data) + writer.bodyBytesLeft -= len(data) + } +} diff --git a/pkg/http/middleware/timeout.go b/pkg/http/middleware/timeout.go new file mode 100644 index 0000000000..50e9d82f22 --- /dev/null +++ b/pkg/http/middleware/timeout.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "context" + "net/http" + "strings" + "time" + + "go.uber.org/zap" +) + +const ( + headerName string = "timeout" +) + +type Timeout struct { + logger *zap.Logger + excluded map[string]struct{} + // The default timeout + defaultTimeout time.Duration + // The max allowed timeout + maxTimeout time.Duration +} + +func NewTimeout(logger *zap.Logger, excluded map[string]struct{}, defaultTimeout time.Duration, maxTimeout time.Duration) *Timeout { + if logger == nil { + panic("cannot build timeout, logger is empty") + } + + if excluded == nil { + excluded = make(map[string]struct{}) + } + + if defaultTimeout.Seconds() == 0 { + defaultTimeout = 60 * time.Second + } + + if maxTimeout == 0 { + maxTimeout = 600 * time.Second + } + + return &Timeout{ + logger: logger.Named(pkgname), + excluded: excluded, + defaultTimeout: defaultTimeout, + maxTimeout: maxTimeout, + } +} + +func (middleware *Timeout) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if _, ok := middleware.excluded[req.URL.Path]; !ok { + actual := middleware.defaultTimeout + incoming := req.Header.Get(headerName) + if incoming != "" { + parsed, err := time.ParseDuration(strings.TrimSpace(incoming) + "s") + if err != nil { + middleware.logger.Warn("cannot parse timeout in header, using default timeout", zap.String("timeout", incoming), zap.Error(err), zap.Any("context", req.Context())) + } else { + if parsed > middleware.maxTimeout { + actual = middleware.maxTimeout + } else { + actual = parsed + } + } + } + + ctx, cancel := context.WithTimeout(req.Context(), actual) + defer cancel() + + req = req.WithContext(ctx) + next.ServeHTTP(rw, req) + return + } + + next.ServeHTTP(rw, req) + }) +} diff --git a/pkg/http/middleware/timeout_test.go b/pkg/http/middleware/timeout_test.go new file mode 100644 index 0000000000..2575bfe7d9 --- /dev/null +++ b/pkg/http/middleware/timeout_test.go @@ -0,0 +1,80 @@ +package middleware + +import ( + "net" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestTimeout(t *testing.T) { + t.Parallel() + + writeTimeout := 6 * time.Second + defaultTimeout := 2 * time.Second + maxTimeout := 4 * time.Second + m := NewTimeout(zap.NewNop(), map[string]struct{}{"/excluded": {}}, defaultTimeout, maxTimeout) + + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + server := &http.Server{ + WriteTimeout: writeTimeout, + Handler: m.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, ok := r.Context().Deadline() + if ok { + <-r.Context().Done() + require.Error(t, r.Context().Err()) + } + w.WriteHeader(204) + })), + } + + go func() { + require.NoError(t, server.Serve(listener)) + }() + + testCases := []struct { + name string + wait time.Duration + header string + path string + }{ + { + name: "WaitTillNoTimeoutForExcludedPath", + wait: 1 * time.Nanosecond, + header: "4", + path: "excluded", + }, + { + name: "WaitTillHeaderTimeout", + wait: 3 * time.Second, + header: "3", + path: "header-timeout", + }, + { + name: "WaitTillMaxTimeout", + wait: 4 * time.Second, + header: "5", + path: "max-timeout", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + start := time.Now() + req, err := http.NewRequest("GET", "http://"+listener.Addr().String()+"/"+tc.path, nil) + require.NoError(t, err) + req.Header.Add(headerName, tc.header) + + _, err = http.DefaultClient.Do(req) + require.NoError(t, err) + + // confirm that we waited at least till the "wait" time + require.GreaterOrEqual(t, time.Since(start), tc.wait) + }) + } +} diff --git a/pkg/http/server/config.go b/pkg/http/server/config.go new file mode 100644 index 0000000000..fb8eb5be11 --- /dev/null +++ b/pkg/http/server/config.go @@ -0,0 +1,27 @@ +package server + +import ( + "go.signoz.io/signoz/pkg/confmap" +) + +// Config satisfies the confmap.Config interface +var _ confmap.Config = (*Config)(nil) + +// Config holds the configuration for http. +type Config struct { + //Address specifies the TCP address for the server to listen on, in the form "host:port". + // If empty, ":http" (port 80) is used. The service names are defined in RFC 6335 and assigned by IANA. + // See net.Dial for details of the address format. + Address string `mapstructure:"address"` +} + +func (c *Config) NewWithDefaults() confmap.Config { + return &Config{ + Address: "0.0.0.0:8080", + } + +} + +func (c *Config) Validate() error { + return nil +} diff --git a/pkg/http/server/doc.go b/pkg/http/server/doc.go new file mode 100644 index 0000000000..9bf12b5ae5 --- /dev/null +++ b/pkg/http/server/doc.go @@ -0,0 +1,2 @@ +// package server contains an implementation of the http server. +package server diff --git a/pkg/http/server/server.go b/pkg/http/server/server.go new file mode 100644 index 0000000000..fbeca1c3a9 --- /dev/null +++ b/pkg/http/server/server.go @@ -0,0 +1,79 @@ +package server + +import ( + "context" + "fmt" + "net/http" + "time" + + "go.signoz.io/signoz/pkg/registry" + "go.uber.org/zap" +) + +var _ registry.NamedService = (*Server)(nil) + +type Server struct { + srv *http.Server + logger *zap.Logger + handler http.Handler + cfg Config + name string +} + +func New(logger *zap.Logger, name string, cfg Config, handler http.Handler) (*Server, error) { + if handler == nil { + return nil, fmt.Errorf("cannot build http server, handler is required") + } + + if logger == nil { + return nil, fmt.Errorf("cannot build http server, logger is required") + } + + if name == "" { + return nil, fmt.Errorf("cannot build http server, name is required") + } + + srv := &http.Server{ + Addr: cfg.Address, + Handler: handler, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, + } + + return &Server{ + srv: srv, + logger: logger.Named("go.signoz.io/pkg/http/server"), + handler: handler, + cfg: cfg, + name: name, + }, nil +} + +func (server *Server) Name() string { + return server.name +} + +func (server *Server) Start(ctx context.Context) error { + server.logger.Info("starting http server", zap.String("address", server.srv.Addr)) + if err := server.srv.ListenAndServe(); err != nil { + if err != http.ErrServerClosed { + server.logger.Error("failed to start server", zap.Error(err), zap.Any("context", ctx)) + return err + } + } + return nil +} + +func (server *Server) Stop(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if err := server.srv.Shutdown(ctx); err != nil { + server.logger.Error("failed to stop server", zap.Error(err), zap.Any("context", ctx)) + return err + } + + server.logger.Info("server stopped gracefully", zap.Any("context", ctx)) + return nil +} diff --git a/pkg/query-service/app/server.go b/pkg/query-service/app/server.go index 4fb4d9ad22..77caa9170b 100644 --- a/pkg/query-service/app/server.go +++ b/pkg/query-service/app/server.go @@ -322,6 +322,7 @@ func (s *Server) createPublicServer(api *APIHandler) (*http.Server, error) { }, nil } +// TODO(remove): Implemented at pkg/http/middleware/logging.go // loggingMiddleware is used for logging public api calls func loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -392,6 +393,7 @@ func LogCommentEnricher(next http.Handler) http.Handler { }) } +// TODO(remove): Implemented at pkg/http/middleware/logging.go // loggingMiddlewarePrivate is used for logging private api calls // from internal services like alert manager func loggingMiddlewarePrivate(next http.Handler) http.Handler { @@ -404,27 +406,32 @@ func loggingMiddlewarePrivate(next http.Handler) http.Handler { }) } +// TODO(remove): Implemented at pkg/http/middleware/logging.go type loggingResponseWriter struct { http.ResponseWriter statusCode int } +// TODO(remove): Implemented at pkg/http/middleware/logging.go func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { // WriteHeader(int) is not called if our response implicitly returns 200 OK, so // we default to that status code. return &loggingResponseWriter{w, http.StatusOK} } +// TODO(remove): Implemented at pkg/http/middleware/logging.go func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } +// TODO(remove): Implemented at pkg/http/middleware/logging.go // Flush implements the http.Flush interface. func (lrw *loggingResponseWriter) Flush() { lrw.ResponseWriter.(http.Flusher).Flush() } +// TODO(remove): Implemented at pkg/http/middleware/logging.go // Support websockets func (lrw *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { h, ok := lrw.ResponseWriter.(http.Hijacker) @@ -538,6 +545,7 @@ func (s *Server) analyticsMiddleware(next http.Handler) http.Handler { }) } +// TODO(remove): Implemented at pkg/http/middleware/timeout.go func getRouteContextTimeout(overrideTimeout string) time.Duration { var timeout time.Duration var err error @@ -554,6 +562,7 @@ func getRouteContextTimeout(overrideTimeout string) time.Duration { return constants.ContextTimeout } +// TODO(remove): Implemented at pkg/http/middleware/timeout.go func setTimeoutMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/pkg/query-service/app/server_test.go b/pkg/query-service/app/server_test.go index afdc06fa33..49fc6191fb 100644 --- a/pkg/query-service/app/server_test.go +++ b/pkg/query-service/app/server_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) +// TODO(remove): Implemented at pkg/http/middleware/timeout_test.go func TestGetRouteContextTimeout(t *testing.T) { var testGetRouteContextTimeoutData = []struct { Name string diff --git a/pkg/registry/doc.go b/pkg/registry/doc.go new file mode 100644 index 0000000000..ff2debbefe --- /dev/null +++ b/pkg/registry/doc.go @@ -0,0 +1,3 @@ +// package registry contains a simple implementation of https://github.com/google/guava/wiki/ServiceExplained +// Here the the "ServiceManager" is called the "Registry" +package registry diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go new file mode 100644 index 0000000000..850a4389a9 --- /dev/null +++ b/pkg/registry/registry.go @@ -0,0 +1,84 @@ +package registry + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "syscall" + + "go.uber.org/zap" +) + +type Registry struct { + services []NamedService + logger *zap.Logger + startCh chan error + stopCh chan error +} + +// New creates a new registry of services. It needs at least one service in the input. +func New(logger *zap.Logger, services ...NamedService) (*Registry, error) { + if logger == nil { + return nil, fmt.Errorf("cannot build registry, logger is required") + } + + if len(services) == 0 { + return nil, fmt.Errorf("cannot build registry, at least one service is required") + } + + return &Registry{ + logger: logger.Named("go.signoz.io/pkg/registry"), + services: services, + startCh: make(chan error, 1), + stopCh: make(chan error, len(services)), + }, nil +} + +func (r *Registry) Start(ctx context.Context) error { + for _, s := range r.services { + go func(s Service) { + err := s.Start(ctx) + r.startCh <- err + }(s) + } + + return nil +} + +func (r *Registry) Wait(ctx context.Context) error { + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) + + select { + case <-ctx.Done(): + r.logger.Info("caught context error, exiting", zap.Any("context", ctx)) + case s := <-interrupt: + r.logger.Info("caught interrupt signal, exiting", zap.Any("context", ctx), zap.Any("signal", s)) + case err := <-r.startCh: + r.logger.Info("caught service error, exiting", zap.Any("context", ctx), zap.Error(err)) + return err + } + + return nil +} + +func (r *Registry) Stop(ctx context.Context) error { + for _, s := range r.services { + go func(s Service) { + err := s.Stop(ctx) + r.stopCh <- err + }(s) + } + + errs := make([]error, len(r.services)) + for i := 0; i < len(r.services); i++ { + err := <-r.stopCh + if err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} diff --git a/pkg/registry/registry_test.go b/pkg/registry/registry_test.go new file mode 100644 index 0000000000..12ae1d8862 --- /dev/null +++ b/pkg/registry/registry_test.go @@ -0,0 +1,56 @@ +package registry + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestRegistryWith2HttpServers(t *testing.T) { + http1, err := newHttpService("http1") + require.NoError(t, err) + + http2, err := newHttpService("http2") + require.NoError(t, err) + + registry, err := New(zap.NewNop(), http1, http2) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, registry.Start(ctx)) + require.NoError(t, registry.Wait(ctx)) + require.NoError(t, registry.Stop(ctx)) + }() + cancel() + + wg.Wait() +} + +func TestRegistryWith2HttpServersWithoutWait(t *testing.T) { + http1, err := newHttpService("http1") + require.NoError(t, err) + + http2, err := newHttpService("http2") + require.NoError(t, err) + + registry, err := New(zap.NewNop(), http1, http2) + require.NoError(t, err) + + ctx := context.Background() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, registry.Start(ctx)) + require.NoError(t, registry.Stop(ctx)) + }() + + wg.Wait() +} diff --git a/pkg/registry/service.go b/pkg/registry/service.go new file mode 100644 index 0000000000..38df4f0a4f --- /dev/null +++ b/pkg/registry/service.go @@ -0,0 +1,16 @@ +package registry + +import "context" + +type Service interface { + // Starts a service. The service should return an error if it cannot be started. + Start(context.Context) error + // Stops a service. + Stop(context.Context) error +} + +type NamedService interface { + // Identifier of a service. It should be unique across all services. + Name() string + Service +} diff --git a/pkg/registry/service_test.go b/pkg/registry/service_test.go new file mode 100644 index 0000000000..dc0621e962 --- /dev/null +++ b/pkg/registry/service_test.go @@ -0,0 +1,49 @@ +package registry + +import ( + "context" + "net" + "net/http" +) + +var _ NamedService = (*httpService)(nil) + +type httpService struct { + Listener net.Listener + Server *http.Server + name string +} + +func newHttpService(name string) (*httpService, error) { + return &httpService{ + name: name, + Server: &http.Server{}, + }, nil +} + +func (service *httpService) Name() string { + return service.name +} + +func (service *httpService) Start(ctx context.Context) error { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + return err + } + service.Listener = listener + + if err := service.Server.Serve(service.Listener); err != nil { + if err != http.ErrServerClosed { + return err + } + } + return nil +} + +func (service *httpService) Stop(ctx context.Context) error { + if err := service.Server.Shutdown(ctx); err != nil { + return err + } + + return nil +}