mirror of
https://github.com/peridotbuild/peridot.git
synced 2024-12-14 07:08:30 +00:00
Base improvements
This commit is contained in:
parent
de95a0bedb
commit
dffec12bd9
@ -33,7 +33,6 @@ const (
|
||||
EnvVarFrontendRequiredOIDCGroup EnvVar = "FRONTEND_REQUIRED_OIDC_GROUP"
|
||||
EnvVarTemporalNamespace EnvVar = "TEMPORAL_NAMESPACE"
|
||||
EnvVarTemporalAddress EnvVar = "TEMPORAL_ADDRESS"
|
||||
EnvVarFrontendAllowUnauthenticated EnvVar = "FRONTEND_ALLOW_UNAUTHENTICATED"
|
||||
EnvVarFrontendSelf EnvVar = "FRONTEND_SELF"
|
||||
)
|
||||
|
||||
@ -64,27 +63,38 @@ var defaultCliFlagsTemporal = append(defaultCliFlagsDatabaseOnly, []cli.Flag{
|
||||
},
|
||||
}...)
|
||||
|
||||
var defaultCliFlags = append(defaultCliFlagsDatabaseOnly, []cli.Flag{
|
||||
var defaultCliFlagsNoAuth = append(defaultCliFlagsDatabaseOnly, []cli.Flag{
|
||||
&cli.IntFlag{
|
||||
Name: "grpc-port",
|
||||
Aliases: []string{"p"},
|
||||
Usage: "gRPC port",
|
||||
EnvVars: []string{string(EnvVarGRPCPort)},
|
||||
Value: 8080,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "gateway-port",
|
||||
Aliases: []string{"g"},
|
||||
Usage: "gRPC gateway port",
|
||||
EnvVars: []string{string(EnvVarGatewayPort)},
|
||||
Value: 8081,
|
||||
},
|
||||
}...)
|
||||
|
||||
var defaultCliFlags = append(defaultCliFlagsNoAuth, []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "oidc-issuer",
|
||||
Usage: "OIDC issuer",
|
||||
EnvVars: []string{string(EnvVarFrontendOIDCIssuer)},
|
||||
Value: "https://accounts.rockylinux.org/auth/realms/rocky",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "required-oidc-group",
|
||||
Usage: "OIDC group that is required to access the frontend",
|
||||
EnvVars: []string{string(EnvVarFrontendRequiredOIDCGroup)},
|
||||
},
|
||||
}...)
|
||||
|
||||
var defaultFrontendNoAuthCliFlags = []cli.Flag{
|
||||
&cli.IntFlag{
|
||||
Name: "port",
|
||||
Aliases: []string{"p"},
|
||||
Usage: "frontend port",
|
||||
EnvVars: []string{string(EnvVarFrontendPort)},
|
||||
Value: 9111,
|
||||
@ -118,11 +128,6 @@ var defaultFrontendCliFlags = append(defaultFrontendNoAuthCliFlags, []cli.Flag{
|
||||
Usage: "OIDC group that is required to access the frontend",
|
||||
EnvVars: []string{string(EnvVarFrontendRequiredOIDCGroup)},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "allow-unauthenticated",
|
||||
Usage: "Allow unauthenticated access to the frontend",
|
||||
EnvVars: []string{string(EnvVarFrontendAllowUnauthenticated)},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "self",
|
||||
Usage: "Endpoint pointing to the frontend",
|
||||
@ -135,6 +140,11 @@ func WithDefaultCliFlags(flags ...cli.Flag) []cli.Flag {
|
||||
return append(defaultCliFlags, flags...)
|
||||
}
|
||||
|
||||
// WithDefaultCliFlagsNoAuth adds the default cli flags to the app.
|
||||
func WithDefaultCliFlagsNoAuth(flags ...cli.Flag) []cli.Flag {
|
||||
return append(defaultCliFlagsNoAuth, flags...)
|
||||
}
|
||||
|
||||
// WithDefaultCliFlagsTemporal adds the default cli flags to the app.
|
||||
func WithDefaultCliFlagsTemporal(flags ...cli.Flag) []cli.Flag {
|
||||
return append(defaultCliFlagsTemporal, flags...)
|
||||
@ -167,8 +177,8 @@ func FlagsToGRPCServerOptions(ctx *cli.Context) []GRPCServerOption {
|
||||
func FlagsToFrontendInfo(ctx *cli.Context) *FrontendInfo {
|
||||
return &FrontendInfo{
|
||||
Title: ctx.App.Name,
|
||||
Port: ctx.Int("port"),
|
||||
Self: ctx.String("self"),
|
||||
AllowUnauthenticated: ctx.Bool("allow-unauthenticated"),
|
||||
OIDCIssuer: ctx.String("oidc-issuer"),
|
||||
OIDCClientID: ctx.String("oidc-client-id"),
|
||||
OIDCClientSecret: ctx.String("oidc-client-secret"),
|
||||
@ -177,6 +187,14 @@ func FlagsToFrontendInfo(ctx *cli.Context) *FrontendInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// FlagsToOidcInterceptorDetails converts the cli flags to oidc interceptor details.
|
||||
func FlagsToOidcInterceptorDetails(ctx *cli.Context) *OidcInterceptorDetails {
|
||||
return &OidcInterceptorDetails{
|
||||
Issuer: ctx.String("oidc-issuer"),
|
||||
Group: ctx.String("required-oidc-group"),
|
||||
}
|
||||
}
|
||||
|
||||
// GetDBFromFlags gets the database from the cli flags.
|
||||
func GetDBFromFlags(ctx *cli.Context) *DB {
|
||||
db, err := NewDB(ctx.String("database-url"))
|
||||
|
@ -15,21 +15,39 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"embed"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"html/template"
|
||||
"math/rand"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
type FrontendInfo struct {
|
||||
// NoRun is a flag to disable running the frontend server
|
||||
NoRun bool
|
||||
|
||||
// MuxHandler is the HTTP handler (can be nil)
|
||||
MuxHandler http.Handler
|
||||
|
||||
// Title to add to the HTML page
|
||||
Title string
|
||||
|
||||
// Port is the port to serve the frontend on
|
||||
Port int
|
||||
|
||||
// Self is the URL to the frontend server
|
||||
Self string
|
||||
|
||||
@ -52,12 +70,25 @@ type FrontendInfo struct {
|
||||
OIDCGroup string
|
||||
|
||||
// OIDCUserInfoOverride is a flag to override the userinfo endpoint
|
||||
// todo(mustafa): we don't need to use it yet since RESF deploys cluster external Keycloak.
|
||||
OIDCUserInfoOverride string
|
||||
|
||||
// AdditionalContent is a map of paths to content to serve
|
||||
AdditionalContent map[string][]byte
|
||||
|
||||
// Internal
|
||||
|
||||
// unauthenticatedTemplate is the template for the unauthenticated page
|
||||
unauthenticatedTemplate string
|
||||
}
|
||||
|
||||
type frontendTemplateData struct {
|
||||
User *oidc.UserInfo
|
||||
Prefix string
|
||||
}
|
||||
|
||||
//go:embed assets/oh_no_unauthenticated.png
|
||||
var ohNoGopher []byte
|
||||
var frontendHtmlTemplate = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
@ -69,6 +100,8 @@ var frontendHtmlTemplate = `
|
||||
/>
|
||||
<title>{{.Title}}</title>
|
||||
|
||||
<link rel="icon" type="image/png" href="/_ga/favicon.png" />
|
||||
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com" />
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
|
||||
<link
|
||||
@ -79,10 +112,54 @@ var frontendHtmlTemplate = `
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
|
||||
{{if .User}}
|
||||
<script>
|
||||
window.__peridot_user__ = {
|
||||
sub: {{.User.Subject}},
|
||||
email: {{.User.Email}},
|
||||
}
|
||||
</script>
|
||||
{{end}}
|
||||
{{if .Prefix}}
|
||||
<script>window.__peridot_prefix__ = '{{.Prefix}}'.replace('\\', '');</script>
|
||||
{{end}}
|
||||
|
||||
<script src="{{.BundleJS}}"></script>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
var frontendUnauthenticated = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta
|
||||
name="viewport"
|
||||
content="width=device-width, initial-scale=1, viewport-fit=cover"
|
||||
/>
|
||||
<title>{{.Title}} - Unauthenticated</title>
|
||||
|
||||
<link rel="icon" type="image/png" href="/_ga/favicon.png" />
|
||||
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com" />
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;600;700&display=swap"
|
||||
/>
|
||||
</head>
|
||||
<body>
|
||||
<div style="font-family:Roboto,sans-serif;display:flex;flex-flow:column;justify-content:center;align-items:center;width:100vw;height:100vh">
|
||||
<img height="120px" src="/_ga/oh_no_unauthenticated.png" /><br />
|
||||
<code>{{.}}</code><br /><br />
|
||||
|
||||
<a href="/">Go back</a>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
const frontendAuthCookieKey = "auth_bearer"
|
||||
|
||||
func readDir(embedfs *embed.FS, root string) ([]string, error) {
|
||||
var paths []string
|
||||
@ -104,32 +181,127 @@ func readDir(embedfs *embed.FS, root string) ([]string, error) {
|
||||
return paths, nil
|
||||
}
|
||||
|
||||
func FrontendServer(info *FrontendInfo, embedfs *embed.FS) {
|
||||
port := 9111
|
||||
func (info *FrontendInfo) renderUnauthorized(w http.ResponseWriter, message string) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
|
||||
tmpl, err := template.New("unauthorized.html").Parse(info.unauthenticatedTemplate)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = tmpl.Execute(w, message)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// frontendAuthHandler verifies that the user is authenticated
|
||||
// if not redirects to /auth/oidc/login
|
||||
func (info *FrontendInfo) frontendAuthHandler(provider *oidc.Provider, h http.Handler) http.Handler {
|
||||
excludedSuffixes := []string{
|
||||
"/auth/oidc/login",
|
||||
"/auth/oidc/callback",
|
||||
"/_ga/oh_no_unauthenticated.png",
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// if the path is excluded, then serve the request
|
||||
for _, suffix := range excludedSuffixes {
|
||||
if strings.HasSuffix(r.URL.Path, suffix) {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// get auth cookie
|
||||
authCookie, err := r.Cookie(frontendAuthCookieKey)
|
||||
if err != nil {
|
||||
// redirect to login
|
||||
http.Redirect(w, r, info.Self+"/auth/oidc/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// verify the token
|
||||
userInfo, err := provider.UserInfo(r.Context(), oauth2.StaticTokenSource(&oauth2.Token{
|
||||
AccessToken: authCookie.Value,
|
||||
TokenType: "Bearer",
|
||||
}))
|
||||
if err != nil {
|
||||
// redirect to login
|
||||
http.Redirect(w, r, info.Self+"/auth/oidc/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the user is in the group
|
||||
var claims oidcClaims
|
||||
err = userInfo.Claims(&claims)
|
||||
if err != nil {
|
||||
// redirect to login
|
||||
http.Redirect(w, r, info.Self+"/auth/oidc/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
groups := claims.Groups
|
||||
if info.OIDCGroup != "" {
|
||||
if !Contains(groups, info.OIDCGroup) {
|
||||
// show unauthenticated page
|
||||
info.renderUnauthorized(w, fmt.Sprintf("User is not in group %s", info.OIDCGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Add the user to the context
|
||||
ctx := context.WithValue(r.Context(), "user", userInfo)
|
||||
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func FrontendServer(info *FrontendInfo, embedfs *embed.FS) error {
|
||||
if info == nil {
|
||||
info = &FrontendInfo{}
|
||||
return errors.New("frontend info is nil")
|
||||
}
|
||||
port := info.Port
|
||||
|
||||
// Using info.Self, let's determine a path prefix
|
||||
// If info.Self is empty, then we'll use the root path
|
||||
prefix := ""
|
||||
if info.Self != "" {
|
||||
parsed, err := url.Parse(info.Self)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse self url: %w", err)
|
||||
}
|
||||
prefix = parsed.Path
|
||||
}
|
||||
|
||||
newTemplate := frontendHtmlTemplate
|
||||
newUnauthenticatedTemplate := frontendUnauthenticated
|
||||
|
||||
// Set the title
|
||||
if info.Title == "" {
|
||||
info.Title = "Peridot"
|
||||
}
|
||||
newTemplate = strings.ReplaceAll(newTemplate, "{{.Title}}", info.Title)
|
||||
newUnauthenticatedTemplate = strings.ReplaceAll(newUnauthenticatedTemplate, "{{.Title}}", info.Title)
|
||||
|
||||
pathToContent := map[string][]byte{}
|
||||
info.unauthenticatedTemplate = newUnauthenticatedTemplate
|
||||
|
||||
pathToContent := map[string][]byte{
|
||||
prefix + "/_ga/oh_no_unauthenticated.png": ohNoGopher,
|
||||
}
|
||||
|
||||
// Read the files from the embedfs
|
||||
paths, err := readDir(embedfs, "bundle")
|
||||
if err != nil {
|
||||
LogFatalf("failed to read embedfs: %v", err)
|
||||
return fmt.Errorf("failed to read embedfs: %w", err)
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
content, err := embedfs.ReadFile(path)
|
||||
if err != nil {
|
||||
LogFatalf("failed to read embedfs: %v", err)
|
||||
return fmt.Errorf("failed to read embedfs: %w", err)
|
||||
}
|
||||
|
||||
// Sha256 hash of the content to add to name
|
||||
@ -139,7 +311,7 @@ func FrontendServer(info *FrontendInfo, embedfs *embed.FS) {
|
||||
|
||||
ext := filepath.Ext(path)
|
||||
noExtName := path[:len(path)-len(ext)]
|
||||
newPath := fmt.Sprintf("/_ga/%s.%s%s", noExtName, hashSum[:8], ext)
|
||||
newPath := fmt.Sprintf("%s/_ga/%s.%s%s", prefix, noExtName, hashSum[:8], ext)
|
||||
|
||||
pathToContent[newPath] = content
|
||||
|
||||
@ -163,7 +335,7 @@ func FrontendServer(info *FrontendInfo, embedfs *embed.FS) {
|
||||
}
|
||||
|
||||
// Serve the content
|
||||
http.HandleFunc("/_ga/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.HandleFunc(prefix+"/_ga/", func(w http.ResponseWriter, r *http.Request) {
|
||||
mimeType := mime.TypeByExtension(filepath.Ext(r.URL.Path))
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
@ -172,24 +344,174 @@ func FrontendServer(info *FrontendInfo, embedfs *embed.FS) {
|
||||
_, _ = w.Write([]byte(pathToContent[r.URL.Path]))
|
||||
})
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.HandleFunc(prefix+"/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
_, _ = w.Write([]byte(newTemplate))
|
||||
|
||||
tmpl, err := template.New("index.html").Parse(newTemplate)
|
||||
if err != nil {
|
||||
info.renderUnauthorized(w, fmt.Sprintf("Failed to parse template: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// If user is in context, then add it to the template
|
||||
data := frontendTemplateData{
|
||||
Prefix: prefix,
|
||||
}
|
||||
|
||||
user := r.Context().Value("user")
|
||||
if user != nil {
|
||||
data.User = user.(*oidc.UserInfo)
|
||||
}
|
||||
|
||||
err = tmpl.Execute(w, data)
|
||||
if err != nil {
|
||||
info.renderUnauthorized(w, fmt.Sprintf("Failed to execute template: %v", err))
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
// Handle other _ga meta routes
|
||||
http.HandleFunc("/_ga/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.HandleFunc(prefix+"/_ga/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
// Handle auth routes
|
||||
http.HandleFunc("/auth/oidc/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
var provider *oidc.Provider
|
||||
if !info.NoAuth {
|
||||
ctx := context.TODO()
|
||||
provider, err = oidc.NewProvider(ctx, info.OIDCIssuer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create oidc provider: %w", err)
|
||||
}
|
||||
|
||||
})
|
||||
redirectURL := info.Self + "/auth/oidc/callback"
|
||||
|
||||
LogInfof("starting frontend server on port %d", port)
|
||||
if err := http.ListenAndServe(":"+strconv.Itoa(port), nil); err != nil {
|
||||
LogFatalf("failed to start frontend server: %v", err)
|
||||
oauth2Config := oauth2.Config{
|
||||
ClientID: info.OIDCClientID,
|
||||
ClientSecret: info.OIDCClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "groups"},
|
||||
}
|
||||
|
||||
http.HandleFunc(prefix+"/auth/oidc/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate a random state
|
||||
state := ""
|
||||
for i := 0; i < 16; i++ {
|
||||
state += strconv.Itoa(rand.Intn(10))
|
||||
}
|
||||
|
||||
// Generate the auth url
|
||||
authURL := oauth2Config.AuthCodeURL(state)
|
||||
|
||||
// Set the state cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "auth_state",
|
||||
Value: state,
|
||||
Path: "/",
|
||||
// expires in 2 minutes
|
||||
MaxAge: 120,
|
||||
// secure if self is https
|
||||
Secure: strings.HasPrefix(info.Self, "https://"),
|
||||
})
|
||||
|
||||
// Redirect to the auth url
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
})
|
||||
|
||||
http.HandleFunc(prefix+"/auth/oidc/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the state cookie
|
||||
stateCookie, err := r.Cookie("auth_state")
|
||||
if err != nil {
|
||||
info.renderUnauthorized(w, fmt.Sprintf("Failed to get state cookie: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Get the state query param
|
||||
stateQueryParam := r.URL.Query().Get("state")
|
||||
if stateQueryParam == "" {
|
||||
info.renderUnauthorized(w, "No state parameter in query")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the state cookie and state query param match
|
||||
if stateCookie.Value != stateQueryParam {
|
||||
info.renderUnauthorized(w, "State cookie and state query param do not match")
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange the code for a token
|
||||
token, err := oauth2Config.Exchange(r.Context(), r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
info.renderUnauthorized(w, fmt.Sprintf("Failed to exchange code: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
accessToken := token.AccessToken
|
||||
userInfo, err := provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token))
|
||||
if err != nil {
|
||||
info.renderUnauthorized(w, fmt.Sprintf("Failed to get userinfo: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the user is in the group
|
||||
if info.OIDCGroup != "" {
|
||||
var claims oidcClaims
|
||||
err := userInfo.Claims(&claims)
|
||||
if err != nil {
|
||||
info.renderUnauthorized(w, fmt.Sprintf("Failed to get claims: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
groups := claims.Groups
|
||||
|
||||
found := false
|
||||
for _, group := range groups {
|
||||
if group == info.OIDCGroup {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
info.renderUnauthorized(w, fmt.Sprintf("User is not in group %s", info.OIDCGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set the auth cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: frontendAuthCookieKey,
|
||||
Value: accessToken,
|
||||
Path: "/",
|
||||
// expires in 2 hours
|
||||
MaxAge: 7200,
|
||||
// secure if self is https
|
||||
Secure: strings.HasPrefix(info.Self, "https://"),
|
||||
})
|
||||
|
||||
// Redirect to self, this is due to the "root" not being / for all apps
|
||||
http.Redirect(w, r, info.Self, http.StatusFound)
|
||||
})
|
||||
}
|
||||
|
||||
var handler http.Handler = nil
|
||||
// if auth is enabled as well as AllowUnauthenticated is false, then wrap the handler with the auth handler
|
||||
if !info.NoAuth && !info.AllowUnauthenticated {
|
||||
handler = info.frontendAuthHandler(provider, http.DefaultServeMux)
|
||||
} else {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
info.MuxHandler = handler
|
||||
|
||||
if !info.NoRun {
|
||||
LogInfof("starting frontend server on port %d", port)
|
||||
if err := http.ListenAndServe(":"+strconv.Itoa(port), handler); err != nil {
|
||||
return fmt.Errorf("failed to start frontend server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -16,16 +16,17 @@ package base
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -47,6 +48,7 @@ type GRPCServer struct {
|
||||
grpcPort int
|
||||
gatewayPort int
|
||||
noGrpcGateway bool
|
||||
noMetrics bool
|
||||
}
|
||||
|
||||
type GrpcEndpointRegister func(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error
|
||||
@ -129,6 +131,39 @@ func WithNoGRPCGateway() GRPCServerOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithNoMetrics disables the Prometheus metrics for the gRPC server.
|
||||
func WithNoMetrics() GRPCServerOption {
|
||||
return func(g *GRPCServer) {
|
||||
g.noMetrics = true
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultServeMuxOptions() []runtime.ServeMuxOption {
|
||||
return []runtime.ServeMuxOption{
|
||||
runtime.WithIncomingHeaderMatcher(func(s string) (string, bool) {
|
||||
switch strings.ToLower(s) {
|
||||
case "authorization",
|
||||
"cookie":
|
||||
return s, true
|
||||
}
|
||||
|
||||
if strings.ToLower(s) == "content-type" {
|
||||
return "original-content-type", true
|
||||
}
|
||||
|
||||
return s, false
|
||||
}),
|
||||
runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{
|
||||
MarshalOptions: protojson.MarshalOptions{
|
||||
EmitUnpopulated: false,
|
||||
},
|
||||
UnmarshalOptions: protojson.UnmarshalOptions{
|
||||
DiscardUnknown: true,
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// NewGRPCServer creates a new gRPC-server with gRPC-gateway, default interceptors
|
||||
// and exposed Prometheus metrics.
|
||||
func NewGRPCServer(opts ...GRPCServerOption) (*GRPCServer, error) {
|
||||
@ -157,6 +192,9 @@ func NewGRPCServer(opts ...GRPCServerOption) (*GRPCServer, error) {
|
||||
if g.gatewayPort == 0 {
|
||||
g.gatewayPort = g.grpcPort + 1
|
||||
}
|
||||
if len(g.muxOptions) == 0 {
|
||||
g.muxOptions = DefaultServeMuxOptions()
|
||||
}
|
||||
|
||||
// Always prepend the insecure dial option
|
||||
// RESF deploys with Istio, which handles mTLS
|
||||
@ -180,15 +218,13 @@ func NewGRPCServer(opts ...GRPCServerOption) (*GRPCServer, error) {
|
||||
|
||||
g.server = grpc.NewServer(g.serverOptions...)
|
||||
|
||||
if !g.noGrpcGateway {
|
||||
g.gatewayMux = runtime.NewServeMux(g.muxOptions...)
|
||||
g.gatewayMux = runtime.NewServeMux(g.muxOptions...)
|
||||
|
||||
// Create gateway client connection
|
||||
var err error
|
||||
g.gatewayClientConn, err = grpc.Dial("localhost:"+strconv.Itoa(g.grpcPort), g.dialOptions...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Create gateway client connection
|
||||
var err error
|
||||
g.gatewayClientConn, err = grpc.Dial("localhost:"+strconv.Itoa(g.grpcPort), g.dialOptions...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return g, nil
|
||||
@ -199,10 +235,6 @@ func (g *GRPCServer) RegisterService(register func(*grpc.Server)) {
|
||||
}
|
||||
|
||||
func (g *GRPCServer) GatewayEndpoints(registerEndpoints ...GrpcEndpointRegister) error {
|
||||
if g.noGrpcGateway {
|
||||
return errors.New("gRPC-gateway is disabled")
|
||||
}
|
||||
|
||||
for _, register := range registerEndpoints {
|
||||
if err := register(context.Background(), g.gatewayMux, g.gatewayClientConn); err != nil {
|
||||
return err
|
||||
@ -212,6 +244,10 @@ func (g *GRPCServer) GatewayEndpoints(registerEndpoints ...GrpcEndpointRegister)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GRPCServer) GatewayMux() *runtime.ServeMux {
|
||||
return g.gatewayMux
|
||||
}
|
||||
|
||||
func (g *GRPCServer) Start() error {
|
||||
// Create gRPC listener
|
||||
grpcLis, err := net.Listen("tcp", ":"+strconv.Itoa(g.grpcPort))
|
||||
@ -254,17 +290,19 @@ func (g *GRPCServer) Start() error {
|
||||
}
|
||||
|
||||
// Serve proxmux
|
||||
wg.Add(1)
|
||||
go func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
if !g.noMetrics {
|
||||
wg.Add(1)
|
||||
go func(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
promMux := http.NewServeMux()
|
||||
promMux.Handle("/metrics", promhttp.Handler())
|
||||
err := http.ListenAndServe(":7332", promMux)
|
||||
if err != nil {
|
||||
LogFatalf("Prometheus mux failed to serve: %v", err.Error())
|
||||
}
|
||||
}(&wg)
|
||||
promMux := http.NewServeMux()
|
||||
promMux.Handle("/metrics", promhttp.Handler())
|
||||
err := http.ListenAndServe(":7332", promMux)
|
||||
if err != nil {
|
||||
LogFatalf("Prometheus mux failed to serve: %v", err.Error())
|
||||
}
|
||||
}(&wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user