peridot/vendor/github.com/authzed/grpcutil/middleware.go

119 lines
4.2 KiB
Go
Raw Normal View History

2022-07-07 20:11:50 +00:00
package grpcutil
import (
"context"
"fmt"
"strings"
2022-07-07 20:11:50 +00:00
grpcmw "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/validator"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
)
// IgnoreAuthMixin is an empty type meant to be embedded in a gRPC service
// implementation. Embedding it gives the service an AuthFuncOverride method,
// which tells the grpc-ecosystem auth middleware to skip its auth requirements
// for that service.
type IgnoreAuthMixin struct{}

// Compile-time check that IgnoreAuthMixin satisfies the grpc_auth override hook.
var _ grpc_auth.ServiceAuthFuncOverride = IgnoreAuthMixin{}
// AuthFuncOverride implements the grpc_auth.ServiceAuthFuncOverride by
// performing a no-op.
func (m IgnoreAuthMixin) AuthFuncOverride(ctx context.Context, _ string) (context.Context, error) {
2022-07-07 20:11:50 +00:00
return ctx, nil
}
// AuthlessHealthServer implements a gRPC health endpoint that will ignore any auth
// requirements set by github.com/grpc-ecosystem/go-grpc-middleware/auth.
//
// It combines two embedded values:
//   - *health.Server provides the standard health-checking service and
//     SetServingStatus, which are promoted onto this type.
//   - IgnoreAuthMixin provides AuthFuncOverride, which makes the grpc_auth
//     middleware skip authentication for this service.
type AuthlessHealthServer struct {
	*health.Server
	IgnoreAuthMixin
}
// NewAuthlessHealthServer builds an AuthlessHealthServer around a freshly
// created health.Server, so the returned server is exempt from the grpc_auth
// middleware.
func NewAuthlessHealthServer() *AuthlessHealthServer {
	srv := health.NewServer()
	return &AuthlessHealthServer{Server: srv}
}
// SetServicesHealthy marks every service described by svcDesc as SERVING on
// the embedded health server. Each service is keyed by its ServiceName.
func (s *AuthlessHealthServer) SetServicesHealthy(svcDesc ...*grpc.ServiceDesc) {
	status := healthpb.HealthCheckResponse_SERVING
	for idx := range svcDesc {
		name := svcDesc[idx].ServiceName
		s.SetServingStatus(name, status)
	}
}
// DefaultUnaryMiddleware is the recommended baseline list of unary server
// interceptors. Each entry is expected to pass a request through untouched
// when it does not apply to it. Currently the list holds only the
// go-grpc-middleware request validator.
var DefaultUnaryMiddleware = []grpc.UnaryServerInterceptor{
	grpcvalidate.UnaryServerInterceptor(),
}
// WrapMethods wraps all non-streaming endpoints with the given list of interceptors.
// It returns a copy of the ServiceDesc with the new wrapped methods.
//
// The interceptor the server passes to each method handler runs outermost,
// followed by interceptors in the order given.
//
// svcDesc is received by value, but its Methods slice still shares a backing
// array with the caller's descriptor (often a package-level generated
// variable). A new slice is built here so that:
//   - the caller's descriptor is never modified, and
//   - calling WrapMethods repeatedly on the same descriptor does not stack wrappers.
func WrapMethods(svcDesc grpc.ServiceDesc, interceptors ...grpc.UnaryServerInterceptor) (wrapped *grpc.ServiceDesc) {
	chain := grpcmw.ChainUnaryServer(interceptors...)
	methods := make([]grpc.MethodDesc, len(svcDesc.Methods))
	for i, m := range svcDesc.Methods {
		handler := m.Handler
		methods[i] = grpc.MethodDesc{
			MethodName: m.MethodName,
			Handler: func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
				// The server-supplied interceptor may be nil; substitute a
				// pass-through so the chain below is always well-formed.
				if interceptor == nil {
					interceptor = NoopUnaryInterceptor
				}
				return handler(srv, ctx, dec, grpcmw.ChainUnaryServer(interceptor, chain))
			},
		}
	}
	svcDesc.Methods = methods
	return &svcDesc
}
// WrapStreams wraps all streaming endpoints with the given list of interceptors.
// It returns a copy of the ServiceDesc with the new wrapped methods.
//
// Each wrapped handler runs the interceptors in the order given. It passes them
// a StreamServerInfo whose FullMethod is "/<ServiceName>/<StreamName>".
//
// svcDesc is received by value, but its Streams slice still shares a backing
// array with the caller's descriptor. A new slice is built here so that:
//   - the caller's descriptor is never modified, and
//   - calling WrapStreams repeatedly on the same descriptor does not stack wrappers.
func WrapStreams(svcDesc grpc.ServiceDesc, interceptors ...grpc.StreamServerInterceptor) (wrapped *grpc.ServiceDesc) {
	chain := grpcmw.ChainStreamServer(interceptors...)
	streams := make([]grpc.StreamDesc, len(svcDesc.Streams))
	for i, s := range svcDesc.Streams {
		handler := s.Handler
		info := &grpc.StreamServerInfo{
			FullMethod:     fmt.Sprintf("/%s/%s", svcDesc.ServiceName, s.StreamName),
			IsClientStream: s.ClientStreams,
			IsServerStream: s.ServerStreams,
		}
		streams[i] = grpc.StreamDesc{
			StreamName:    s.StreamName,
			ClientStreams: s.ClientStreams,
			ServerStreams: s.ServerStreams,
			Handler: func(srv interface{}, stream grpc.ServerStream) error {
				return chain(srv, stream, info, handler)
			},
		}
	}
	svcDesc.Streams = streams
	return &svcDesc
}
// NoopUnaryInterceptor is a gRPC middleware that does not do anything.
func NoopUnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
2022-07-07 20:11:50 +00:00
return handler(ctx, req)
}
// SplitMethodName splits a gRPC full method string such as
// "/package.Service/Method" into its service and method parts.
//
// One leading slash is stripped. The string is then split at the first
// remaining slash, so anything after it (including further slashes) becomes
// the method part. If there is no slash left, both parts are "unknown".
//
// This function is vendored from:
// https://github.com/grpc-ecosystem/go-grpc-prometheus/blob/82c243799c991a7d5859215fba44a81834a52a71/util.go#L31-L37
//
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// Apache 2.0 Licensed
func SplitMethodName(fullMethodName string) (string, string) {
	trimmed := strings.TrimPrefix(fullMethodName, "/")
	slash := strings.Index(trimmed, "/")
	if slash < 0 {
		return "unknown", "unknown"
	}
	service, method := trimmed[:slash], trimmed[slash+1:]
	return service, method
}