Base improvements

This commit is contained in:
Mustafa Gezen 2023-08-25 18:48:11 +02:00
parent de95a0bedb
commit dffec12bd9
3 changed files with 428 additions and 50 deletions

View File

@ -33,7 +33,6 @@ const (
EnvVarFrontendRequiredOIDCGroup EnvVar = "FRONTEND_REQUIRED_OIDC_GROUP"
EnvVarTemporalNamespace EnvVar = "TEMPORAL_NAMESPACE"
EnvVarTemporalAddress EnvVar = "TEMPORAL_ADDRESS"
EnvVarFrontendAllowUnauthenticated EnvVar = "FRONTEND_ALLOW_UNAUTHENTICATED"
EnvVarFrontendSelf EnvVar = "FRONTEND_SELF"
)
@ -64,27 +63,38 @@ var defaultCliFlagsTemporal = append(defaultCliFlagsDatabaseOnly, []cli.Flag{
},
}...)
var defaultCliFlags = append(defaultCliFlagsDatabaseOnly, []cli.Flag{
var defaultCliFlagsNoAuth = append(defaultCliFlagsDatabaseOnly, []cli.Flag{
&cli.IntFlag{
Name: "grpc-port",
Aliases: []string{"p"},
Usage: "gRPC port",
EnvVars: []string{string(EnvVarGRPCPort)},
Value: 8080,
},
&cli.IntFlag{
Name: "gateway-port",
Aliases: []string{"g"},
Usage: "gRPC gateway port",
EnvVars: []string{string(EnvVarGatewayPort)},
Value: 8081,
},
}...)
var defaultCliFlags = append(defaultCliFlagsNoAuth, []cli.Flag{
&cli.StringFlag{
Name: "oidc-issuer",
Usage: "OIDC issuer",
EnvVars: []string{string(EnvVarFrontendOIDCIssuer)},
Value: "https://accounts.rockylinux.org/auth/realms/rocky",
},
&cli.StringFlag{
Name: "required-oidc-group",
Usage: "OIDC group that is required to access the frontend",
EnvVars: []string{string(EnvVarFrontendRequiredOIDCGroup)},
},
}...)
var defaultFrontendNoAuthCliFlags = []cli.Flag{
&cli.IntFlag{
Name: "port",
Aliases: []string{"p"},
Usage: "frontend port",
EnvVars: []string{string(EnvVarFrontendPort)},
Value: 9111,
@ -118,11 +128,6 @@ var defaultFrontendCliFlags = append(defaultFrontendNoAuthCliFlags, []cli.Flag{
Usage: "OIDC group that is required to access the frontend",
EnvVars: []string{string(EnvVarFrontendRequiredOIDCGroup)},
},
&cli.StringFlag{
Name: "allow-unauthenticated",
Usage: "Allow unauthenticated access to the frontend",
EnvVars: []string{string(EnvVarFrontendAllowUnauthenticated)},
},
&cli.StringFlag{
Name: "self",
Usage: "Endpoint pointing to the frontend",
@ -135,6 +140,11 @@ func WithDefaultCliFlags(flags ...cli.Flag) []cli.Flag {
return append(defaultCliFlags, flags...)
}
// WithDefaultCliFlagsNoAuth adds the default cli flags to the app.
func WithDefaultCliFlagsNoAuth(flags ...cli.Flag) []cli.Flag {
return append(defaultCliFlagsNoAuth, flags...)
}
// WithDefaultCliFlagsTemporal adds the default cli flags to the app.
func WithDefaultCliFlagsTemporal(flags ...cli.Flag) []cli.Flag {
return append(defaultCliFlagsTemporal, flags...)
@ -167,8 +177,8 @@ func FlagsToGRPCServerOptions(ctx *cli.Context) []GRPCServerOption {
func FlagsToFrontendInfo(ctx *cli.Context) *FrontendInfo {
return &FrontendInfo{
Title: ctx.App.Name,
Port: ctx.Int("port"),
Self: ctx.String("self"),
AllowUnauthenticated: ctx.Bool("allow-unauthenticated"),
OIDCIssuer: ctx.String("oidc-issuer"),
OIDCClientID: ctx.String("oidc-client-id"),
OIDCClientSecret: ctx.String("oidc-client-secret"),
@ -177,6 +187,14 @@ func FlagsToFrontendInfo(ctx *cli.Context) *FrontendInfo {
}
}
// FlagsToOidcInterceptorDetails converts the cli flags to oidc interceptor details.
func FlagsToOidcInterceptorDetails(ctx *cli.Context) *OidcInterceptorDetails {
return &OidcInterceptorDetails{
Issuer: ctx.String("oidc-issuer"),
Group: ctx.String("required-oidc-group"),
}
}
// GetDBFromFlags gets the database from the cli flags.
func GetDBFromFlags(ctx *cli.Context) *DB {
db, err := NewDB(ctx.String("database-url"))

View File

@ -15,21 +15,39 @@
package base
import (
"context"
"crypto/sha256"
"embed"
"encoding/hex"
"errors"
"fmt"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
"html/template"
"math/rand"
"mime"
"net/http"
"net/url"
"path/filepath"
"strconv"
"strings"
_ "embed"
)
type FrontendInfo struct {
// NoRun is a flag to disable running the frontend server
NoRun bool
// MuxHandler is the HTTP handler (can be nil)
MuxHandler http.Handler
// Title to add to the HTML page
Title string
// Port is the port to serve the frontend on
Port int
// Self is the URL to the frontend server
Self string
@ -52,12 +70,25 @@ type FrontendInfo struct {
OIDCGroup string
// OIDCUserInfoOverride is a flag to override the userinfo endpoint
// todo(mustafa): we don't need to use it yet since RESF deploys cluster external Keycloak.
OIDCUserInfoOverride string
// AdditionalContent is a map of paths to content to serve
AdditionalContent map[string][]byte
// Internal
// unauthenticatedTemplate is the template for the unauthenticated page
unauthenticatedTemplate string
}
type frontendTemplateData struct {
User *oidc.UserInfo
Prefix string
}
//go:embed assets/oh_no_unauthenticated.png
var ohNoGopher []byte
var frontendHtmlTemplate = `
<!DOCTYPE html>
<html>
@ -69,6 +100,8 @@ var frontendHtmlTemplate = `
/>
<title>{{.Title}}</title>
<link rel="icon" type="image/png" href="/_ga/favicon.png" />
<link rel="preconnect" href="https://fonts.googleapis.com" />
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
<link
@ -79,10 +112,54 @@ var frontendHtmlTemplate = `
<body>
<div id="app"></div>
{{if .User}}
<script>
window.__peridot_user__ = {
sub: {{.User.Subject}},
email: {{.User.Email}},
}
</script>
{{end}}
{{if .Prefix}}
<script>window.__peridot_prefix__ = '{{.Prefix}}'.replace('\\', '');</script>
{{end}}
<script src="{{.BundleJS}}"></script>
</body>
</html>
`
var frontendUnauthenticated = `
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<meta
name="viewport"
content="width=device-width, initial-scale=1, viewport-fit=cover"
/>
<title>{{.Title}} - Unauthenticated</title>
<link rel="icon" type="image/png" href="/_ga/favicon.png" />
<link rel="preconnect" href="https://fonts.googleapis.com" />
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;600;700&display=swap"
/>
</head>
<body>
<div style="font-family:Roboto,sans-serif;display:flex;flex-flow:column;justify-content:center;align-items:center;width:100vw;height:100vh">
<img height="120px" src="/_ga/oh_no_unauthenticated.png" /><br />
<code>{{.}}</code><br /><br />
<a href="/">Go back</a>
</div>
</body>
</html>
`
const frontendAuthCookieKey = "auth_bearer"
func readDir(embedfs *embed.FS, root string) ([]string, error) {
var paths []string
@ -104,32 +181,127 @@ func readDir(embedfs *embed.FS, root string) ([]string, error) {
return paths, nil
}
func FrontendServer(info *FrontendInfo, embedfs *embed.FS) {
port := 9111
func (info *FrontendInfo) renderUnauthorized(w http.ResponseWriter, message string) {
w.Header().Set("Content-Type", "text/html")
tmpl, err := template.New("unauthorized.html").Parse(info.unauthenticatedTemplate)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
err = tmpl.Execute(w, message)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
// frontendAuthHandler verifies that the user is authenticated
// if not redirects to /auth/oidc/login
func (info *FrontendInfo) frontendAuthHandler(provider *oidc.Provider, h http.Handler) http.Handler {
excludedSuffixes := []string{
"/auth/oidc/login",
"/auth/oidc/callback",
"/_ga/oh_no_unauthenticated.png",
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// if the path is excluded, then serve the request
for _, suffix := range excludedSuffixes {
if strings.HasSuffix(r.URL.Path, suffix) {
h.ServeHTTP(w, r)
return
}
}
// get auth cookie
authCookie, err := r.Cookie(frontendAuthCookieKey)
if err != nil {
// redirect to login
http.Redirect(w, r, info.Self+"/auth/oidc/login", http.StatusFound)
return
}
// verify the token
userInfo, err := provider.UserInfo(r.Context(), oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: authCookie.Value,
TokenType: "Bearer",
}))
if err != nil {
// redirect to login
http.Redirect(w, r, info.Self+"/auth/oidc/login", http.StatusFound)
return
}
// Check if the user is in the group
var claims oidcClaims
err = userInfo.Claims(&claims)
if err != nil {
// redirect to login
http.Redirect(w, r, info.Self+"/auth/oidc/login", http.StatusFound)
return
}
groups := claims.Groups
if info.OIDCGroup != "" {
if !Contains(groups, info.OIDCGroup) {
// show unauthenticated page
info.renderUnauthorized(w, fmt.Sprintf("User is not in group %s", info.OIDCGroup))
return
}
}
// Add the user to the context
ctx := context.WithValue(r.Context(), "user", userInfo)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
func FrontendServer(info *FrontendInfo, embedfs *embed.FS) error {
if info == nil {
info = &FrontendInfo{}
return errors.New("frontend info is nil")
}
port := info.Port
// Using info.Self, let's determine a path prefix
// If info.Self is empty, then we'll use the root path
prefix := ""
if info.Self != "" {
parsed, err := url.Parse(info.Self)
if err != nil {
return fmt.Errorf("failed to parse self url: %w", err)
}
prefix = parsed.Path
}
newTemplate := frontendHtmlTemplate
newUnauthenticatedTemplate := frontendUnauthenticated
// Set the title
if info.Title == "" {
info.Title = "Peridot"
}
newTemplate = strings.ReplaceAll(newTemplate, "{{.Title}}", info.Title)
newUnauthenticatedTemplate = strings.ReplaceAll(newUnauthenticatedTemplate, "{{.Title}}", info.Title)
pathToContent := map[string][]byte{}
info.unauthenticatedTemplate = newUnauthenticatedTemplate
pathToContent := map[string][]byte{
prefix + "/_ga/oh_no_unauthenticated.png": ohNoGopher,
}
// Read the files from the embedfs
paths, err := readDir(embedfs, "bundle")
if err != nil {
LogFatalf("failed to read embedfs: %v", err)
return fmt.Errorf("failed to read embedfs: %w", err)
}
for _, path := range paths {
content, err := embedfs.ReadFile(path)
if err != nil {
LogFatalf("failed to read embedfs: %v", err)
return fmt.Errorf("failed to read embedfs: %w", err)
}
// Sha256 hash of the content to add to name
@ -139,7 +311,7 @@ func FrontendServer(info *FrontendInfo, embedfs *embed.FS) {
ext := filepath.Ext(path)
noExtName := path[:len(path)-len(ext)]
newPath := fmt.Sprintf("/_ga/%s.%s%s", noExtName, hashSum[:8], ext)
newPath := fmt.Sprintf("%s/_ga/%s.%s%s", prefix, noExtName, hashSum[:8], ext)
pathToContent[newPath] = content
@ -163,7 +335,7 @@ func FrontendServer(info *FrontendInfo, embedfs *embed.FS) {
}
// Serve the content
http.HandleFunc("/_ga/", func(w http.ResponseWriter, r *http.Request) {
http.HandleFunc(prefix+"/_ga/", func(w http.ResponseWriter, r *http.Request) {
mimeType := mime.TypeByExtension(filepath.Ext(r.URL.Path))
if mimeType == "" {
mimeType = "application/octet-stream"
@ -172,24 +344,174 @@ func FrontendServer(info *FrontendInfo, embedfs *embed.FS) {
_, _ = w.Write([]byte(pathToContent[r.URL.Path]))
})
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.HandleFunc(prefix+"/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
_, _ = w.Write([]byte(newTemplate))
tmpl, err := template.New("index.html").Parse(newTemplate)
if err != nil {
info.renderUnauthorized(w, fmt.Sprintf("Failed to parse template: %v", err))
return
}
// If user is in context, then add it to the template
data := frontendTemplateData{
Prefix: prefix,
}
user := r.Context().Value("user")
if user != nil {
data.User = user.(*oidc.UserInfo)
}
err = tmpl.Execute(w, data)
if err != nil {
info.renderUnauthorized(w, fmt.Sprintf("Failed to execute template: %v", err))
return
}
})
// Handle other _ga meta routes
http.HandleFunc("/_ga/healthz", func(w http.ResponseWriter, r *http.Request) {
http.HandleFunc(prefix+"/_ga/healthz", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte("ok"))
})
// Handle auth routes
http.HandleFunc("/auth/oidc/login", func(w http.ResponseWriter, r *http.Request) {
var provider *oidc.Provider
if !info.NoAuth {
ctx := context.TODO()
provider, err = oidc.NewProvider(ctx, info.OIDCIssuer)
if err != nil {
return fmt.Errorf("failed to create oidc provider: %w", err)
}
})
redirectURL := info.Self + "/auth/oidc/callback"
LogInfof("starting frontend server on port %d", port)
if err := http.ListenAndServe(":"+strconv.Itoa(port), nil); err != nil {
LogFatalf("failed to start frontend server: %v", err)
oauth2Config := oauth2.Config{
ClientID: info.OIDCClientID,
ClientSecret: info.OIDCClientSecret,
Endpoint: provider.Endpoint(),
RedirectURL: redirectURL,
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "groups"},
}
http.HandleFunc(prefix+"/auth/oidc/login", func(w http.ResponseWriter, r *http.Request) {
// Generate a random state
state := ""
for i := 0; i < 16; i++ {
state += strconv.Itoa(rand.Intn(10))
}
// Generate the auth url
authURL := oauth2Config.AuthCodeURL(state)
// Set the state cookie
http.SetCookie(w, &http.Cookie{
Name: "auth_state",
Value: state,
Path: "/",
// expires in 2 minutes
MaxAge: 120,
// secure if self is https
Secure: strings.HasPrefix(info.Self, "https://"),
})
// Redirect to the auth url
http.Redirect(w, r, authURL, http.StatusFound)
})
http.HandleFunc(prefix+"/auth/oidc/callback", func(w http.ResponseWriter, r *http.Request) {
// Get the state cookie
stateCookie, err := r.Cookie("auth_state")
if err != nil {
info.renderUnauthorized(w, fmt.Sprintf("Failed to get state cookie: %v", err))
return
}
// Get the state query param
stateQueryParam := r.URL.Query().Get("state")
if stateQueryParam == "" {
info.renderUnauthorized(w, "No state parameter in query")
return
}
// Check if the state cookie and state query param match
if stateCookie.Value != stateQueryParam {
info.renderUnauthorized(w, "State cookie and state query param do not match")
return
}
// Exchange the code for a token
token, err := oauth2Config.Exchange(r.Context(), r.URL.Query().Get("code"))
if err != nil {
info.renderUnauthorized(w, fmt.Sprintf("Failed to exchange code: %v", err))
return
}
// Verify the token
accessToken := token.AccessToken
userInfo, err := provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token))
if err != nil {
info.renderUnauthorized(w, fmt.Sprintf("Failed to get userinfo: %v", err))
return
}
// Check if the user is in the group
if info.OIDCGroup != "" {
var claims oidcClaims
err := userInfo.Claims(&claims)
if err != nil {
info.renderUnauthorized(w, fmt.Sprintf("Failed to get claims: %v", err))
return
}
groups := claims.Groups
found := false
for _, group := range groups {
if group == info.OIDCGroup {
found = true
break
}
}
if !found {
info.renderUnauthorized(w, fmt.Sprintf("User is not in group %s", info.OIDCGroup))
return
}
}
// Set the auth cookie
http.SetCookie(w, &http.Cookie{
Name: frontendAuthCookieKey,
Value: accessToken,
Path: "/",
// expires in 2 hours
MaxAge: 7200,
// secure if self is https
Secure: strings.HasPrefix(info.Self, "https://"),
})
// Redirect to self, this is due to the "root" not being / for all apps
http.Redirect(w, r, info.Self, http.StatusFound)
})
}
var handler http.Handler = nil
// if auth is enabled as well as AllowUnauthenticated is false, then wrap the handler with the auth handler
if !info.NoAuth && !info.AllowUnauthenticated {
handler = info.frontendAuthHandler(provider, http.DefaultServeMux)
} else {
handler = http.DefaultServeMux
}
info.MuxHandler = handler
if !info.NoRun {
LogInfof("starting frontend server on port %d", port)
if err := http.ListenAndServe(":"+strconv.Itoa(port), handler); err != nil {
return fmt.Errorf("failed to start frontend server: %w", err)
}
}
return nil
}

View File

@ -16,16 +16,17 @@ package base
import (
"context"
"errors"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/prometheus/client_golang/prometheus/promhttp"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/encoding/protojson"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
)
@ -47,6 +48,7 @@ type GRPCServer struct {
grpcPort int
gatewayPort int
noGrpcGateway bool
noMetrics bool
}
type GrpcEndpointRegister func(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error
@ -129,6 +131,39 @@ func WithNoGRPCGateway() GRPCServerOption {
}
}
// WithNoMetrics disables the Prometheus metrics for the gRPC server.
func WithNoMetrics() GRPCServerOption {
return func(g *GRPCServer) {
g.noMetrics = true
}
}
func DefaultServeMuxOptions() []runtime.ServeMuxOption {
return []runtime.ServeMuxOption{
runtime.WithIncomingHeaderMatcher(func(s string) (string, bool) {
switch strings.ToLower(s) {
case "authorization",
"cookie":
return s, true
}
if strings.ToLower(s) == "content-type" {
return "original-content-type", true
}
return s, false
}),
runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: false,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
}),
}
}
// NewGRPCServer creates a new gRPC-server with gRPC-gateway, default interceptors
// and exposed Prometheus metrics.
func NewGRPCServer(opts ...GRPCServerOption) (*GRPCServer, error) {
@ -157,6 +192,9 @@ func NewGRPCServer(opts ...GRPCServerOption) (*GRPCServer, error) {
if g.gatewayPort == 0 {
g.gatewayPort = g.grpcPort + 1
}
if len(g.muxOptions) == 0 {
g.muxOptions = DefaultServeMuxOptions()
}
// Always prepend the insecure dial option
// RESF deploys with Istio, which handles mTLS
@ -180,15 +218,13 @@ func NewGRPCServer(opts ...GRPCServerOption) (*GRPCServer, error) {
g.server = grpc.NewServer(g.serverOptions...)
if !g.noGrpcGateway {
g.gatewayMux = runtime.NewServeMux(g.muxOptions...)
g.gatewayMux = runtime.NewServeMux(g.muxOptions...)
// Create gateway client connection
var err error
g.gatewayClientConn, err = grpc.Dial("localhost:"+strconv.Itoa(g.grpcPort), g.dialOptions...)
if err != nil {
return nil, err
}
// Create gateway client connection
var err error
g.gatewayClientConn, err = grpc.Dial("localhost:"+strconv.Itoa(g.grpcPort), g.dialOptions...)
if err != nil {
return nil, err
}
return g, nil
@ -199,10 +235,6 @@ func (g *GRPCServer) RegisterService(register func(*grpc.Server)) {
}
func (g *GRPCServer) GatewayEndpoints(registerEndpoints ...GrpcEndpointRegister) error {
if g.noGrpcGateway {
return errors.New("gRPC-gateway is disabled")
}
for _, register := range registerEndpoints {
if err := register(context.Background(), g.gatewayMux, g.gatewayClientConn); err != nil {
return err
@ -212,6 +244,10 @@ func (g *GRPCServer) GatewayEndpoints(registerEndpoints ...GrpcEndpointRegister)
return nil
}
func (g *GRPCServer) GatewayMux() *runtime.ServeMux {
return g.gatewayMux
}
func (g *GRPCServer) Start() error {
// Create gRPC listener
grpcLis, err := net.Listen("tcp", ":"+strconv.Itoa(g.grpcPort))
@ -254,17 +290,19 @@ func (g *GRPCServer) Start() error {
}
// Serve proxmux
wg.Add(1)
go func(wg *sync.WaitGroup) {
defer wg.Done()
if !g.noMetrics {
wg.Add(1)
go func(wg *sync.WaitGroup) {
defer wg.Done()
promMux := http.NewServeMux()
promMux.Handle("/metrics", promhttp.Handler())
err := http.ListenAndServe(":7332", promMux)
if err != nil {
LogFatalf("Prometheus mux failed to serve: %v", err.Error())
}
}(&wg)
promMux := http.NewServeMux()
promMux.Handle("/metrics", promhttp.Handler())
err := http.ListenAndServe(":7332", promMux)
if err != nil {
LogFatalf("Prometheus mux failed to serve: %v", err.Error())
}
}(&wg)
}
wg.Wait()