// Mirror of https://github.com/rocky-linux/peridot.git
// (synced 2024-11-24 14:11:25 +00:00; 369 lines, 11 KiB, Go).
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
||
|
|
||
|
package auth
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"mime"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"cloud.google.com/go/auth/internal"
|
||
|
)
|
||
|
|
||
|
// AuthorizationHandler drives the user-facing step of a 3-legged OAuth flow.
// It is given the consent page URL (authCodeURL) at which the user should
// approve access, and it returns the authorization code and the state value
// obtained from that approval, or an error if the step failed.
type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
|
||
|
|
||
|
// Options3LO are the options for doing a 3-legged OAuth2 flow.
|
||
|
type Options3LO struct {
|
||
|
// ClientID is the application's ID.
|
||
|
ClientID string
|
||
|
// ClientSecret is the application's secret. Not required if AuthHandlerOpts
|
||
|
// is set.
|
||
|
ClientSecret string
|
||
|
// AuthURL is the URL for authenticating.
|
||
|
AuthURL string
|
||
|
// TokenURL is the URL for retrieving a token.
|
||
|
TokenURL string
|
||
|
// AuthStyle is used to describe how to client info in the token request.
|
||
|
AuthStyle Style
|
||
|
// RefreshToken is the token used to refresh the credential. Not required
|
||
|
// if AuthHandlerOpts is set.
|
||
|
RefreshToken string
|
||
|
// RedirectURL is the URL to redirect users to. Optional.
|
||
|
RedirectURL string
|
||
|
// Scopes specifies requested permissions for the Token. Optional.
|
||
|
Scopes []string
|
||
|
|
||
|
// URLParams are the set of values to apply to the token exchange. Optional.
|
||
|
URLParams url.Values
|
||
|
// Client is the client to be used to make the underlying token requests.
|
||
|
// Optional.
|
||
|
Client *http.Client
|
||
|
// EarlyTokenExpiry is the time before the token expires that it should be
|
||
|
// refreshed. If not set the default value is 3 minutes and 45 seconds.
|
||
|
// Optional.
|
||
|
EarlyTokenExpiry time.Duration
|
||
|
|
||
|
// AuthHandlerOpts provides a set of options for doing a
|
||
|
// 3-legged OAuth2 flow with a custom [AuthorizationHandler]. Optional.
|
||
|
AuthHandlerOpts *AuthorizationHandlerOptions
|
||
|
}
|
||
|
|
||
|
func (o *Options3LO) validate() error {
|
||
|
if o == nil {
|
||
|
return errors.New("auth: options must be provided")
|
||
|
}
|
||
|
if o.ClientID == "" {
|
||
|
return errors.New("auth: client ID must be provided")
|
||
|
}
|
||
|
if o.AuthHandlerOpts == nil && o.ClientSecret == "" {
|
||
|
return errors.New("auth: client secret must be provided")
|
||
|
}
|
||
|
if o.AuthURL == "" {
|
||
|
return errors.New("auth: auth URL must be provided")
|
||
|
}
|
||
|
if o.TokenURL == "" {
|
||
|
return errors.New("auth: token URL must be provided")
|
||
|
}
|
||
|
if o.AuthStyle == StyleUnknown {
|
||
|
return errors.New("auth: auth style must be provided")
|
||
|
}
|
||
|
if o.AuthHandlerOpts == nil && o.RefreshToken == "" {
|
||
|
return errors.New("auth: refresh token must be provided")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// PKCEOptions carries the Proof Key for Code Exchange (PKCE, RFC 7636)
// values used by the authorization-code flow.
type PKCEOptions struct {
	// Challenge is the unpadded, base64url-encoded code challenge derived
	// from Verifier. It is sent on the consent page URL.
	Challenge string
	// ChallengeMethod names how Challenge was derived from Verifier
	// (for example "S256"). It is sent on the consent page URL.
	ChallengeMethod string
	// Verifier is the original secret that Challenge was derived from. It is
	// sent with the authorization-code exchange.
	Verifier string
}
|
||
|
|
||
|
// tokenJSON mirrors a JSON token endpoint reply. It holds both the success
// fields and the OAuth2 error fields.
type tokenJSON struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	RefreshToken string `json:"refresh_token"`
	ExpiresIn    int    `json:"expires_in"` // token lifetime in seconds

	// Error fields, populated when the server reports a failure.
	ErrorCode        string `json:"error"`
	ErrorDescription string `json:"error_description"`
	ErrorURI         string `json:"error_uri"`
}
|
||
|
|
||
|
func (e *tokenJSON) expiry() (t time.Time) {
|
||
|
if v := e.ExpiresIn; v != 0 {
|
||
|
return time.Now().Add(time.Duration(v) * time.Second)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (o *Options3LO) client() *http.Client {
|
||
|
if o.Client != nil {
|
||
|
return o.Client
|
||
|
}
|
||
|
return internal.CloneDefaultClient()
|
||
|
}
|
||
|
|
||
|
// authCodeURL returns a URL that points to a OAuth2 consent page.
|
||
|
func (o *Options3LO) authCodeURL(state string, values url.Values) string {
|
||
|
var buf bytes.Buffer
|
||
|
buf.WriteString(o.AuthURL)
|
||
|
v := url.Values{
|
||
|
"response_type": {"code"},
|
||
|
"client_id": {o.ClientID},
|
||
|
}
|
||
|
if o.RedirectURL != "" {
|
||
|
v.Set("redirect_uri", o.RedirectURL)
|
||
|
}
|
||
|
if len(o.Scopes) > 0 {
|
||
|
v.Set("scope", strings.Join(o.Scopes, " "))
|
||
|
}
|
||
|
if state != "" {
|
||
|
v.Set("state", state)
|
||
|
}
|
||
|
if o.AuthHandlerOpts != nil {
|
||
|
if o.AuthHandlerOpts.PKCEOpts != nil &&
|
||
|
o.AuthHandlerOpts.PKCEOpts.Challenge != "" {
|
||
|
v.Set(codeChallengeKey, o.AuthHandlerOpts.PKCEOpts.Challenge)
|
||
|
}
|
||
|
if o.AuthHandlerOpts.PKCEOpts != nil &&
|
||
|
o.AuthHandlerOpts.PKCEOpts.ChallengeMethod != "" {
|
||
|
v.Set(codeChallengeMethodKey, o.AuthHandlerOpts.PKCEOpts.ChallengeMethod)
|
||
|
}
|
||
|
}
|
||
|
for k := range values {
|
||
|
v.Set(k, v.Get(k))
|
||
|
}
|
||
|
if strings.Contains(o.AuthURL, "?") {
|
||
|
buf.WriteByte('&')
|
||
|
} else {
|
||
|
buf.WriteByte('?')
|
||
|
}
|
||
|
buf.WriteString(v.Encode())
|
||
|
return buf.String()
|
||
|
}
|
||
|
|
||
|
// New3LOTokenProvider returns a [TokenProvider] based on the 3-legged OAuth2
|
||
|
// configuration. The TokenProvider is caches and auto-refreshes tokens by
|
||
|
// default.
|
||
|
func New3LOTokenProvider(opts *Options3LO) (TokenProvider, error) {
|
||
|
if err := opts.validate(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if opts.AuthHandlerOpts != nil {
|
||
|
return new3LOTokenProviderWithAuthHandler(opts), nil
|
||
|
}
|
||
|
return NewCachedTokenProvider(&tokenProvider3LO{opts: opts, refreshToken: opts.RefreshToken, client: opts.client()}, &CachedTokenProviderOptions{
|
||
|
ExpireEarly: opts.EarlyTokenExpiry,
|
||
|
}), nil
|
||
|
}
|
||
|
|
||
|
// AuthorizationHandlerOptions provides a set of options to specify for doing a
|
||
|
// 3-legged OAuth2 flow with a custom [AuthorizationHandler].
|
||
|
type AuthorizationHandlerOptions struct {
|
||
|
// AuthorizationHandler specifies the handler used to for the authorization
|
||
|
// part of the flow.
|
||
|
Handler AuthorizationHandler
|
||
|
// State is used verify that the "state" is identical in the request and
|
||
|
// response before exchanging the auth code for OAuth2 token.
|
||
|
State string
|
||
|
// PKCEOpts allows setting configurations for PKCE. Optional.
|
||
|
PKCEOpts *PKCEOptions
|
||
|
}
|
||
|
|
||
|
func new3LOTokenProviderWithAuthHandler(opts *Options3LO) TokenProvider {
|
||
|
return NewCachedTokenProvider(&tokenProviderWithHandler{opts: opts, state: opts.AuthHandlerOpts.State}, &CachedTokenProviderOptions{
|
||
|
ExpireEarly: opts.EarlyTokenExpiry,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// exchange handles the final exchange portion of the 3lo flow. Returns a Token,
|
||
|
// refreshToken, and error.
|
||
|
func (o *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) {
|
||
|
// Build request
|
||
|
v := url.Values{
|
||
|
"grant_type": {"authorization_code"},
|
||
|
"code": {code},
|
||
|
}
|
||
|
if o.RedirectURL != "" {
|
||
|
v.Set("redirect_uri", o.RedirectURL)
|
||
|
}
|
||
|
if o.AuthHandlerOpts != nil &&
|
||
|
o.AuthHandlerOpts.PKCEOpts != nil &&
|
||
|
o.AuthHandlerOpts.PKCEOpts.Verifier != "" {
|
||
|
v.Set(codeVerifierKey, o.AuthHandlerOpts.PKCEOpts.Verifier)
|
||
|
}
|
||
|
for k := range o.URLParams {
|
||
|
v.Set(k, o.URLParams.Get(k))
|
||
|
}
|
||
|
return fetchToken(ctx, o, v)
|
||
|
}
|
||
|
|
||
|
// This struct is not safe for concurrent access alone, but the way it is used
|
||
|
// in this package by wrapping it with a cachedTokenProvider makes it so.
|
||
|
type tokenProvider3LO struct {
|
||
|
opts *Options3LO
|
||
|
client *http.Client
|
||
|
refreshToken string
|
||
|
}
|
||
|
|
||
|
func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) {
|
||
|
if tp.refreshToken == "" {
|
||
|
return nil, errors.New("auth: token expired and refresh token is not set")
|
||
|
}
|
||
|
v := url.Values{
|
||
|
"grant_type": {"refresh_token"},
|
||
|
"refresh_token": {tp.refreshToken},
|
||
|
}
|
||
|
for k := range tp.opts.URLParams {
|
||
|
v.Set(k, tp.opts.URLParams.Get(k))
|
||
|
}
|
||
|
|
||
|
tk, rt, err := fetchToken(ctx, tp.opts, v)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if tp.refreshToken != rt && rt != "" {
|
||
|
tp.refreshToken = rt
|
||
|
}
|
||
|
return tk, err
|
||
|
}
|
||
|
|
||
|
type tokenProviderWithHandler struct {
|
||
|
opts *Options3LO
|
||
|
state string
|
||
|
}
|
||
|
|
||
|
func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) {
|
||
|
url := tp.opts.authCodeURL(tp.state, nil)
|
||
|
code, state, err := tp.opts.AuthHandlerOpts.Handler(url)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if state != tp.state {
|
||
|
return nil, errors.New("auth: state mismatch in 3-legged-OAuth flow")
|
||
|
}
|
||
|
tok, _, err := tp.opts.exchange(ctx, code)
|
||
|
return tok, err
|
||
|
}
|
||
|
|
||
|
// fetchToken returns a Token, refresh token, and/or an error.
|
||
|
func fetchToken(ctx context.Context, o *Options3LO, v url.Values) (*Token, string, error) {
|
||
|
var refreshToken string
|
||
|
if o.AuthStyle == StyleInParams {
|
||
|
if o.ClientID != "" {
|
||
|
v.Set("client_id", o.ClientID)
|
||
|
}
|
||
|
if o.ClientSecret != "" {
|
||
|
v.Set("client_secret", o.ClientSecret)
|
||
|
}
|
||
|
}
|
||
|
req, err := http.NewRequestWithContext(ctx, "POST", o.TokenURL, strings.NewReader(v.Encode()))
|
||
|
if err != nil {
|
||
|
return nil, refreshToken, err
|
||
|
}
|
||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||
|
if o.AuthStyle == StyleInHeader {
|
||
|
req.SetBasicAuth(url.QueryEscape(o.ClientID), url.QueryEscape(o.ClientSecret))
|
||
|
}
|
||
|
|
||
|
// Make request
|
||
|
resp, body, err := internal.DoRequest(o.client(), req)
|
||
|
if err != nil {
|
||
|
return nil, refreshToken, err
|
||
|
}
|
||
|
failureStatus := resp.StatusCode < 200 || resp.StatusCode > 299
|
||
|
tokError := &Error{
|
||
|
Response: resp,
|
||
|
Body: body,
|
||
|
}
|
||
|
|
||
|
var token *Token
|
||
|
// errors ignored because of default switch on content
|
||
|
content, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||
|
switch content {
|
||
|
case "application/x-www-form-urlencoded", "text/plain":
|
||
|
// some endpoints return a query string
|
||
|
vals, err := url.ParseQuery(string(body))
|
||
|
if err != nil {
|
||
|
if failureStatus {
|
||
|
return nil, refreshToken, tokError
|
||
|
}
|
||
|
return nil, refreshToken, fmt.Errorf("auth: cannot parse response: %w", err)
|
||
|
}
|
||
|
tokError.code = vals.Get("error")
|
||
|
tokError.description = vals.Get("error_description")
|
||
|
tokError.uri = vals.Get("error_uri")
|
||
|
token = &Token{
|
||
|
Value: vals.Get("access_token"),
|
||
|
Type: vals.Get("token_type"),
|
||
|
Metadata: make(map[string]interface{}, len(vals)),
|
||
|
}
|
||
|
for k, v := range vals {
|
||
|
token.Metadata[k] = v
|
||
|
}
|
||
|
refreshToken = vals.Get("refresh_token")
|
||
|
e := vals.Get("expires_in")
|
||
|
expires, _ := strconv.Atoi(e)
|
||
|
if expires != 0 {
|
||
|
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
|
||
|
}
|
||
|
default:
|
||
|
var tj tokenJSON
|
||
|
if err = json.Unmarshal(body, &tj); err != nil {
|
||
|
if failureStatus {
|
||
|
return nil, refreshToken, tokError
|
||
|
}
|
||
|
return nil, refreshToken, fmt.Errorf("auth: cannot parse json: %w", err)
|
||
|
}
|
||
|
tokError.code = tj.ErrorCode
|
||
|
tokError.description = tj.ErrorDescription
|
||
|
tokError.uri = tj.ErrorURI
|
||
|
token = &Token{
|
||
|
Value: tj.AccessToken,
|
||
|
Type: tj.TokenType,
|
||
|
Expiry: tj.expiry(),
|
||
|
Metadata: make(map[string]interface{}),
|
||
|
}
|
||
|
json.Unmarshal(body, &token.Metadata) // optional field, skip err check
|
||
|
refreshToken = tj.RefreshToken
|
||
|
}
|
||
|
// according to spec, servers should respond status 400 in error case
|
||
|
// https://www.rfc-editor.org/rfc/rfc6749#section-5.2
|
||
|
// but some unorthodox servers respond 200 in error case
|
||
|
if failureStatus || tokError.code != "" {
|
||
|
return nil, refreshToken, tokError
|
||
|
}
|
||
|
if token.Value == "" {
|
||
|
return nil, refreshToken, errors.New("auth: server response missing access_token")
|
||
|
}
|
||
|
return token, refreshToken, nil
|
||
|
}
|