diff --git a/obsidian/impl/v1/server.go b/obsidian/impl/v1/server.go index 8c08eb0..1b6846d 100644 --- a/obsidian/impl/v1/server.go +++ b/obsidian/impl/v1/server.go @@ -80,9 +80,7 @@ func (s *Server) interceptor(ctx context.Context, req interface{}, usi *grpc.Una func (s *Server) Run() { res := utils.NewGRPCServer( &utils.GRPCOptions{ - ServerOptions: []grpc.ServerOption{ - grpc.UnaryInterceptor(s.interceptor), - }, + Interceptor: s.interceptor, }, func(r *utils.Register) { endpoints := []utils.GrpcEndpointRegister{ diff --git a/peridot/impl/v1/server.go b/peridot/impl/v1/server.go index 39f4c6e..9b82ab0 100644 --- a/peridot/impl/v1/server.go +++ b/peridot/impl/v1/server.go @@ -156,9 +156,9 @@ func (s *Server) Run() { DialOptions: []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), }, + Interceptor: s.interceptor, + ServerInterceptor: s.serverInterceptor, ServerOptions: []grpc.ServerOption{ - grpc.UnaryInterceptor(s.interceptor), - grpc.StreamInterceptor(s.serverInterceptor), grpc.MaxRecvMsgSize(1024 * 1024 * 1024), }, }, diff --git a/peridot/keykeeper/v1/server.go b/peridot/keykeeper/v1/server.go index 9d14225..f8e551e 100644 --- a/peridot/keykeeper/v1/server.go +++ b/peridot/keykeeper/v1/server.go @@ -174,10 +174,8 @@ func (s *Server) Run() { defer wg.Done() res := utils.NewGRPCServer( &utils.GRPCOptions{ - Timeout: &timeout, - ServerOptions: []grpc.ServerOption{ - grpc.UnaryInterceptor(s.interceptor), - }, + Timeout: &timeout, + Interceptor: s.interceptor, }, func(r *utils.Register) { endpoints := []utils.GrpcEndpointRegister{ diff --git a/peridot/yumrepofs/v1/server.go b/peridot/yumrepofs/v1/server.go index d865737..b062f4d 100644 --- a/peridot/yumrepofs/v1/server.go +++ b/peridot/yumrepofs/v1/server.go @@ -88,9 +88,7 @@ func (s *Server) interceptor(ctx context.Context, req interface{}, usi *grpc.Una func (s *Server) Run() { res := utils.NewGRPCServer( &utils.GRPCOptions{ - ServerOptions: []grpc.ServerOption{ - grpc.UnaryInterceptor(s.interceptor), - }, + Interceptor: s.interceptor, }, func(r *utils.Register) { endpoints := []utils.GrpcEndpointRegister{ diff --git a/secparse/admin/impl/server.go b/secparse/admin/impl/server.go index bc1fd54..a1a6c95 100644 --- a/secparse/admin/impl/server.go +++ b/secparse/admin/impl/server.go @@ -79,9 +79,7 @@ func (s *Server) interceptor(ctx context.Context, req interface{}, usi *grpc.Una func (s *Server) Run() { res := utils.NewGRPCServer( &utils.GRPCOptions{ - ServerOptions: []grpc.ServerOption{ - grpc.UnaryInterceptor(s.interceptor), - }, + Interceptor: s.interceptor, }, func(r *utils.Register) { err := secparseadminpb.RegisterSecparseAdminHandlerFromEndpoint( diff --git a/secparse/impl/server.go b/secparse/impl/server.go index f54bf88..fe9a909 100644 --- a/secparse/impl/server.go +++ b/secparse/impl/server.go @@ -67,9 +67,7 @@ func (s *Server) interceptor(ctx context.Context, req interface{}, usi *grpc.Una func (s *Server) Run() { res := utils.NewGRPCServer( &utils.GRPCOptions{ - ServerOptions: []grpc.ServerOption{ - grpc.UnaryInterceptor(s.interceptor), - }, + Interceptor: s.interceptor, }, func(r *utils.Register) { err := secparsepb.RegisterSecparseHandlerFromEndpoint( diff --git a/utils/grpc.go b/utils/grpc.go index 1c00ae4..4aab8cf 100644 --- a/utils/grpc.go +++ b/utils/grpc.go @@ -32,6 +32,8 @@ package utils import ( "context" + "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/proto" "net" @@ -42,6 +44,7 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + _ "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" "github.com/spf13/viper" "google.golang.org/grpc" @@ -100,12 +103,14 @@ func DefaultServeMuxOption() []runtime.ServeMuxOption { } type GRPCOptions struct { - DialOptions []grpc.DialOption - MuxOptions []runtime.ServeMuxOption - ServerOptions []grpc.ServerOption - DisableREST bool - DisableGRPC bool - Timeout *time.Duration + DialOptions []grpc.DialOption + MuxOptions []runtime.ServeMuxOption + ServerOptions []grpc.ServerOption + Interceptor grpc.UnaryServerInterceptor + ServerInterceptor grpc.StreamServerInterceptor + DisableREST bool + DisableGRPC bool + Timeout *time.Duration } type Register struct { @@ -160,7 +165,61 @@ func NewGRPCServer(goptions *GRPCOptions, endpoint func(*Register), serve func(* } opts = append(opts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(1000*1024*1024), grpc.MaxCallSendMsgSize(1000*1024*1024))) - serv := grpc.NewServer(options.ServerOptions...) + serverOpts := options.ServerOptions + // If the server already declares a unary interceptor, let's chain + // and make grpc_prometheus run first + if options.Interceptor != nil { + serverOpts = append(serverOpts, grpc.UnaryInterceptor( + func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + n := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + return options.Interceptor(ctx, req, info, handler) + } + n = func(next grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, usi *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + _, err := grpc_prometheus.UnaryServerInterceptor(ctx, req, info, handler) + if err != nil { + return nil, err + } + + return next(ctx, req, usi, handler) + } + }(n) + + return n(ctx, req, info, handler) + }, + )) + } else { + // Else, only declare prometheus interceptor + serverOpts = append(serverOpts, grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor)) + } + + // If the server already declares a stream interceptor, let's chain + // and make grpc_prometheus run first + if options.ServerInterceptor != nil { + serverOpts = append(serverOpts, grpc.StreamInterceptor( + func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + n := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return options.ServerInterceptor(srv, ss, info, handler) + } + n = func(next grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + err := grpc_prometheus.StreamServerInterceptor(srv, ss, info, handler) + if err != nil { + return err + } + + return next(srv, ss, info, handler) + } + }(n) + + return n(srv, ss, info, handler) + }, + )) + } else { + // Else, only declare prometheus interceptor + serverOpts = append(serverOpts, grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor)) + } + serv := grpc.NewServer(serverOpts...) // background context since this is the "main" app ctx, cancel := context.WithCancel(context.TODO()) @@ -200,9 +259,9 @@ func NewGRPCServer(goptions *GRPCOptions, endpoint func(*Register), serve func(* r.Mount("/", mux) var wg sync.WaitGroup - wg.Add(2) if !options.DisableREST { + wg.Add(1) go func(wg *sync.WaitGroup) { logrus.Infof("starting http server on port %s", viper.GetString("api.port")) @@ -216,22 +275,36 @@ func NewGRPCServer(goptions *GRPCOptions, endpoint func(*Register), serve func(* } if !options.DisableGRPC { - logrus.Infof("starting grpc server on port %s", viper.GetString("grpc.port")) - registerServer := &RegisterServer{ - Server: serv, - } + wg.Add(1) + go func(wg *sync.WaitGroup) { + logrus.Infof("starting grpc server on port %s", viper.GetString("grpc.port")) + registerServer := &RegisterServer{ + Server: serv, + } - if serve != nil { - serve(registerServer) - } + if serve != nil { + serve(registerServer) + } + grpc_prometheus.Register(serv) - err = serv.Serve(lis) - if err != nil { - logrus.Fatalf("failed to serve: %v", err) - } - wg.Done() + err = serv.Serve(lis) + if err != nil { + logrus.Fatalf("failed to serve: %v", err) + } + wg.Done() + }(&wg) } + wg.Add(1) + go func(wg *sync.WaitGroup) { + promMux := http.NewServeMux() + promMux.Handle("/metrics", promhttp.Handler()) + err := http.ListenAndServe(":7332", promMux) + if err != nil { + logrus.Fatalf("could not start prometheus server - %s", err) + } + }(&wg) + return &GRPCServerRes{ Cancel: cancel, WaitGroup: &wg,