From cdef680a944b8bd4fa7f5bf68154f32b9e1e8cdd Mon Sep 17 00:00:00 2001 From: Mustafa Gezen Date: Fri, 25 Aug 2023 20:54:49 +0200 Subject: [PATCH] Switch Oidc.Provider to a generic interface to allow testing --- base/go/BUILD | 2 + base/go/auth.go | 116 +++++++++++++++++- base/go/flags.go | 14 ++- base/go/frontend_server.go | 14 ++- .../mothership/cmd/mship_admin_server/main.go | 5 +- tools/mothership/cmd/mship_dev/main.go | 5 +- 6 files changed, 140 insertions(+), 16 deletions(-) diff --git a/base/go/BUILD b/base/go/BUILD index e037a834..f3de7f8f 100644 --- a/base/go/BUILD +++ b/base/go/BUILD @@ -21,6 +21,7 @@ go_library( "db.go", "flags.go", "frontend_server.go", + "fs.go", "grpc.go", "log.go", "pb.go", @@ -37,6 +38,7 @@ go_library( "//vendor/github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth", "//vendor/github.com/grpc-ecosystem/go-grpc-prometheus", "//vendor/github.com/grpc-ecosystem/grpc-gateway/v2/runtime", + "//vendor/github.com/pkg/errors", "//vendor/github.com/prometheus/client_golang/prometheus/promhttp", "//vendor/github.com/urfave/cli/v2:cli", "//vendor/github.com/wk8/go-ordered-map/v2:go-ordered-map", diff --git a/base/go/auth.go b/base/go/auth.go index e42598e0..0a72e232 100644 --- a/base/go/auth.go +++ b/base/go/auth.go @@ -2,6 +2,7 @@ package base import ( "context" + "errors" "github.com/coreos/go-oidc/v3/oidc" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" "golang.org/x/oauth2" @@ -14,23 +15,76 @@ import ( const UserContextKey = "user" +// OidcInterceptorDetails contains the details for the OIDC interceptor type OidcInterceptorDetails struct { - Issuer string + Provider OidcProvider Group string AllowUnauthenticated bool } +// oidcClaims contains the claims for the OIDC token +// At least the ones we care for at the moment type oidcClaims struct { Groups []string } -func OidcGrpcInterceptor(details *OidcInterceptorDetails) (grpc.UnaryServerInterceptor, error) { - ctx := context.TODO() - provider, err := oidc.NewProvider(ctx, details.Issuer) +// OidcProvider is the interface for OIDC providers +type OidcProvider interface { + UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (UserInfo, error) +} + +// UserInfo is the interface for user info +type UserInfo interface { + Subject() string + Email() string + Claims(v interface{}) error +} + +// OidcProviderImpl is the implementation of OidcProvider +// This is main usage in "real" applications +// Tests should use the TestOidcProvider +type OidcProviderImpl struct { + *oidc.Provider +} + +// UserInfo gets the user info from the OIDC provider +func (o *OidcProviderImpl) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (UserInfo, error) { + userInfo, err := o.Provider.UserInfo(ctx, tokenSource) if err != nil { return nil, err } + return &OidcUserInfo{userInfo}, nil +} + +// OidcUserInfo is the implementation of UserInfo +type OidcUserInfo struct { + UserInfo *oidc.UserInfo +} + +// Subject gets the subject from the user info +func (o *OidcUserInfo) Subject() string { + return o.UserInfo.Subject +} + +// Email gets the email from the user info +func (o *OidcUserInfo) Email() string { + return o.UserInfo.Email +} + +// Claims gets the claims from the user info +func (o *OidcUserInfo) Claims(v interface{}) error { + return o.UserInfo.Claims(v) +} + +// OidcGrpcInterceptor creates a new OIDC interceptor +// This enforces authentication and authorization +// Authorization is as simple as checking if the user is in a group +// If the group is empty, no authorization is enforced +// Authentication enforcement can be disabled by setting AllowUnauthenticated to true +func OidcGrpcInterceptor(details *OidcInterceptorDetails) (grpc.UnaryServerInterceptor, error) { + provider := details.Provider + interceptor := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { @@ -95,3 +149,57 @@ func OidcGrpcInterceptor(details *OidcInterceptorDetails) (grpc.UnaryServerInter return interceptor, nil } + +// TestOidcProvider is a test implementation of OidcProvider +type TestOidcProvider struct { + // This interface is a pointer on purpose, so we can point it to + // a value in main_test and change it in the tests + userInfo *UserInfo +} + +// NewTestOidcProvider creates a new TestOidcProvider +func NewTestOidcProvider(userInfo *UserInfo) *TestOidcProvider { + return &TestOidcProvider{ + userInfo: userInfo, + } +} + +// UserInfo gets the user info from the OIDC provider +func (t *TestOidcProvider) UserInfo(_ context.Context, _ oauth2.TokenSource) (UserInfo, error) { + if t.userInfo == nil { + return nil, errors.New("no user info") + } + return *t.userInfo, nil +} + +// TestUserInfo is a test implementation of UserInfo +type TestUserInfo struct { + subject string + email string + claims map[string]any +} + +// NewTestUserInfo creates a new TestUserInfo +func NewTestUserInfo(subject string, email string, claims map[string]any) *TestUserInfo { + return &TestUserInfo{ + subject: subject, + email: email, + claims: claims, + } +} + +// Subject gets the subject from the user info +func (t *TestUserInfo) Subject() string { + return t.subject +} + +// Email gets the email from the user info +func (t *TestUserInfo) Email() string { + return t.email +} + +// Claims gets the claims from the user info +func (t *TestUserInfo) Claims(v *any) error { + *v = t.claims + return nil +} diff --git a/base/go/flags.go b/base/go/flags.go index b46e1875..c139ec52 100644 --- a/base/go/flags.go +++ b/base/go/flags.go @@ -15,6 +15,7 @@ package base import ( + "github.com/coreos/go-oidc/v3/oidc" "github.com/urfave/cli/v2" "os" ) @@ -188,11 +189,16 @@ func FlagsToFrontendInfo(ctx *cli.Context) *FrontendInfo { } // FlagsToOidcInterceptorDetails converts the cli flags to oidc interceptor details. -func FlagsToOidcInterceptorDetails(ctx *cli.Context) *OidcInterceptorDetails { - return &OidcInterceptorDetails{ - Issuer: ctx.String("oidc-issuer"), - Group: ctx.String("required-oidc-group"), +func FlagsToOidcInterceptorDetails(ctx *cli.Context) (*OidcInterceptorDetails, error) { + provider, err := oidc.NewProvider(ctx.Context, ctx.String("oidc-issuer")) + if err != nil { + return nil, err } + + return &OidcInterceptorDetails{ + Provider: &OidcProviderImpl{provider}, + Group: ctx.String("required-oidc-group"), + }, nil } // GetDBFromFlags gets the database from the cli flags. diff --git a/base/go/frontend_server.go b/base/go/frontend_server.go index 0d1a6afa..43786917 100644 --- a/base/go/frontend_server.go +++ b/base/go/frontend_server.go @@ -83,7 +83,7 @@ type FrontendInfo struct { } type frontendTemplateData struct { - User *oidc.UserInfo + User UserInfo Prefix string } @@ -199,7 +199,7 @@ func (info *FrontendInfo) renderUnauthorized(w http.ResponseWriter, message stri // frontendAuthHandler verifies that the user is authenticated // if not redirects to /auth/oidc/login -func (info *FrontendInfo) frontendAuthHandler(provider *oidc.Provider, h http.Handler) http.Handler { +func (info *FrontendInfo) frontendAuthHandler(provider OidcProvider, h http.Handler) http.Handler { excludedSuffixes := []string{ "/auth/oidc/login", "/auth/oidc/callback", @@ -360,7 +360,7 @@ func FrontendServer(info *FrontendInfo, embedfs *embed.FS) error { user := r.Context().Value("user") if user != nil { - data.User = user.(*oidc.UserInfo) + data.User = user.(UserInfo) } err = tmpl.Execute(w, data) @@ -377,20 +377,22 @@ func FrontendServer(info *FrontendInfo, embedfs *embed.FS) error { }) // Handle auth routes - var provider *oidc.Provider + var provider OidcProvider if !info.NoAuth { ctx := context.TODO() - provider, err = oidc.NewProvider(ctx, info.OIDCIssuer) + provider2, err := oidc.NewProvider(ctx, info.OIDCIssuer) if err != nil { return fmt.Errorf("failed to create oidc provider: %w", err) } + provider = &OidcProviderImpl{provider2} + redirectURL := info.Self + "/auth/oidc/callback" oauth2Config := oauth2.Config{ ClientID: info.OIDCClientID, ClientSecret: info.OIDCClientSecret, - Endpoint: provider.Endpoint(), + Endpoint: provider2.Endpoint(), RedirectURL: redirectURL, Scopes: []string{oidc.ScopeOpenID, "profile", "email", "groups"}, } diff --git a/tools/mothership/cmd/mship_admin_server/main.go b/tools/mothership/cmd/mship_admin_server/main.go index ea768953..6cdba414 100644 --- a/tools/mothership/cmd/mship_admin_server/main.go +++ b/tools/mothership/cmd/mship_admin_server/main.go @@ -22,7 +22,10 @@ import ( ) func run(ctx *cli.Context) error { - oidcInterceptorDetails := base.FlagsToOidcInterceptorDetails(ctx) + oidcInterceptorDetails, err := base.FlagsToOidcInterceptorDetails(ctx) + if err != nil { + return err + } s, err := mothershipadmin_rpc.NewServer( base.GetDBFromFlags(ctx), diff --git a/tools/mothership/cmd/mship_dev/main.go b/tools/mothership/cmd/mship_dev/main.go index 59853a4d..7364a2b0 100644 --- a/tools/mothership/cmd/mship_dev/main.go +++ b/tools/mothership/cmd/mship_dev/main.go @@ -70,7 +70,10 @@ func setupApi(ctx *cli.Context) (*runtime.ServeMux, error) { } func setupAdminApi(ctx *cli.Context) (*runtime.ServeMux, error) { - oidcInterceptorDetails := base.FlagsToOidcInterceptorDetails(ctx) + oidcInterceptorDetails, err := base.FlagsToOidcInterceptorDetails(ctx) + if err != nil { + return nil, err + } s, err := mothershipadmin_rpc.NewServer( base.GetDBFromFlags(ctx),