Merge pull request #61 from mstg/module-v3-support

Module v3 support and misc fixes
This commit is contained in:
resf-prow[bot] 2022-11-04 02:46:43 +00:00 committed by GitHub
commit 6e752a3704
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
389 changed files with 82132 additions and 19721 deletions

2
.envrc
View File

@ -9,3 +9,5 @@ export IBAZEL_USE_LEGACY_WATCHER=0
export STABLE_REGISTRY_SECRET="none" export STABLE_REGISTRY_SECRET="none"
export STABLE_OCI_REGISTRY_NO_NESTED_SUPPORT_IN_2022_SHAME_ON_YOU_AWS="true" export STABLE_OCI_REGISTRY_NO_NESTED_SUPPORT_IN_2022_SHAME_ON_YOU_AWS="true"
PATH_add hack PATH_add hack
[[ -f .envrc.local ]] && source .envrc.local

2
.gitignore vendored
View File

@ -114,3 +114,5 @@ fabric.properties
.ijwb/.idea/runConfigurations.xml .ijwb/.idea/runConfigurations.xml
.ijwb/.idea/google-java-format.xml .ijwb/.idea/google-java-format.xml
.ijwb/.idea/dataSources.xml .ijwb/.idea/dataSources.xml
.envrc.local

View File

@ -1,5 +1,5 @@
<component name="CopyrightManager"> <component name="CopyrightManager">
<settings> <settings default="Peridot">
<module2copyright> <module2copyright>
<element module="OriginalFiles" copyright="Peridot" /> <element module="OriginalFiles" copyright="Peridot" />
</module2copyright> </module2copyright>

View File

@ -109,11 +109,12 @@ local manifestYamlStream = function (value, indent_array_in_object=false, c_docu
image: image, image: image,
tag: tag, tag: tag,
}; };
local istio_mode = if helm_mode then false else if utils.local_image then false else true;
{ {
[nssa]: (if helm_mode then '{{ if not .Values.serviceAccountName }}\n' else '') + manifestYamlStream([ [nssa]: (if helm_mode then '{{ if not .Values.serviceAccountName }}\n' else '') + manifestYamlStream([
// disable namespace creation in helm mode // disable namespace creation in helm mode
if !helm_mode then kubernetes.define_namespace(metadata.namespace, infolabels), if !helm_mode then kubernetes.define_namespace(metadata.namespace, infolabels + { annotations: { 'linkerd.io/inject': 'enabled' } }),
kubernetes.define_service_account( kubernetes.define_service_account(
metadata { metadata {
name: fixed.name, name: fixed.name,
@ -285,22 +286,22 @@ local manifestYamlStream = function (value, indent_array_in_object=false, c_docu
selector=metadata.name, selector=metadata.name,
env=mappings.get_env_from_svc(srv.name) env=mappings.get_env_from_svc(srv.name)
) for srv in services] + ) for srv in services] +
if !helm_mode then [] else [if std.objectHas(srv, 'expose') && srv.expose then kubernetes.define_ingress( if istio_mode then [] else [if std.objectHas(srv, 'expose') && srv.expose then kubernetes.define_ingress(
metadata { metadata {
name: srv.name, name: srv.name,
annotations: ingress_annotations + { annotations: ingress_annotations + {
'kubernetes.io/ingress.class': '{{ .Values.ingressClass | default !"!" }}', 'kubernetes.io/ingress.class': if helm_mode then '{{ .Values.ingressClass | default !"!" }}' else 'kong',
// Secure only by default // Secure only by default
// This produces https, grpcs, etc. // This produces https, grpcs, etc.
// todo(mustafa): check if we need to add an exemption to a protocol (TCP comes to mind) // todo(mustafa): check if we need to add an exemption to a protocol (TCP comes to mind)
'konghq.com/protocols': '{{ .Values.kongProtocols | default !"%ss!"' % std.strReplace(std.strReplace(std.strReplace(srv.name, metadata.name, ''), stage, ''), '-', ''), 'konghq.com/protocols': (if helm_mode then '{{ .Values.kongProtocols | default !"%ss!" }}' else '%ss') % std.strReplace(std.strReplace(std.strReplace(srv.name, metadata.name, ''), stage, ''), '-', ''),
} }
}, },
host=if helm_mode then '{{ .Values.%s.ingressHost }}' % srv.portName else mappings.get(srv.name, user), host=if helm_mode then '{{ .Values.%s.ingressHost }}' % srv.portName else mappings.get(srv.name, user),
port=srv.port, port=srv.port,
srvName=srv.name + '-service', srvName=srv.name + '-service',
) else null for srv in services] + ) else null for srv in services] +
if helm_mode then [] else [kubernetes.define_virtual_service(metadata { name: srv.name + '-internal' }, { if !istio_mode then [] else [kubernetes.define_virtual_service(metadata { name: srv.name + '-internal' }, {
hosts: [vshost(srv)], hosts: [vshost(srv)],
gateways: [], gateways: [],
http: [ http: [
@ -317,7 +318,7 @@ local manifestYamlStream = function (value, indent_array_in_object=false, c_docu
}, },
], ],
},) for srv in services] + },) for srv in services] +
if helm_mode then [] else [if std.objectHas(srv, 'expose') && srv.expose then kubernetes.define_virtual_service( if !istio_mode then [] else [if std.objectHas(srv, 'expose') && srv.expose then kubernetes.define_virtual_service(
metadata { metadata {
name: srv.name, name: srv.name,
annotations: { annotations: {
@ -342,7 +343,7 @@ local manifestYamlStream = function (value, indent_array_in_object=false, c_docu
], ],
} }
) else null for srv in services] + ) else null for srv in services] +
if helm_mode then [] else [{ if !istio_mode then [] else [{
apiVersion: 'security.istio.io/v1beta1', apiVersion: 'security.istio.io/v1beta1',
kind: 'RequestAuthentication', kind: 'RequestAuthentication',
metadata: metadata { metadata: metadata {
@ -363,7 +364,7 @@ local manifestYamlStream = function (value, indent_array_in_object=false, c_docu
}] else [], }] else [],
}, },
} for srv in services] + } for srv in services] +
if helm_mode then [] else [{ if !istio_mode then [] else [{
apiVersion: 'security.istio.io/v1beta1', apiVersion: 'security.istio.io/v1beta1',
kind: 'AuthorizationPolicy', kind: 'AuthorizationPolicy',
metadata: metadata { metadata: metadata {
@ -388,7 +389,7 @@ local manifestYamlStream = function (value, indent_array_in_object=false, c_docu
}], }],
}, },
} for srv in services] + } for srv in services] +
if helm_mode then [] else [kubernetes.define_destination_rule(metadata { name: srv.name }, { if !istio_mode then [] else [kubernetes.define_destination_rule(metadata { name: srv.name }, {
host: vshost(srv), host: vshost(srv),
trafficPolicy: { trafficPolicy: {
tls: { tls: {

View File

@ -1,6 +1,7 @@
# sync-ignore-file: true local local_domain = std.extVar("local_domain");
{ {
local_domain: '.pdev.resf.localhost', local_domain: local_domain,
default_domain: '.build.resf.org', default_domain: '.build.resf.org',
service_mappings: { service_mappings: {
'peridotserver-http': { 'peridotserver-http': {

5
go.mod
View File

@ -13,7 +13,7 @@ require (
github.com/antchfx/xmlquery v1.3.6 // indirect github.com/antchfx/xmlquery v1.3.6 // indirect
github.com/authzed/authzed-go v0.3.0 github.com/authzed/authzed-go v0.3.0
github.com/authzed/grpcutil v0.0.0-20211115181027-063820eb2511 github.com/authzed/grpcutil v0.0.0-20211115181027-063820eb2511
github.com/aws/aws-sdk-go v1.36.12 github.com/aws/aws-sdk-go v1.44.129
github.com/cavaliergopher/rpm v1.2.0 github.com/cavaliergopher/rpm v1.2.0
github.com/coreos/go-oidc/v3 v3.0.0 github.com/coreos/go-oidc/v3 v3.0.0
github.com/fatih/color v1.12.0 github.com/fatih/color v1.12.0
@ -37,7 +37,7 @@ require (
github.com/pelletier/go-toml v1.8.1 // indirect github.com/pelletier/go-toml v1.8.1 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_golang v1.13.0
github.com/rocky-linux/srpmproc v0.3.16 github.com/rocky-linux/srpmproc v0.4.1
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.1.3 github.com/spf13/cobra v1.1.3
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
@ -50,6 +50,7 @@ require (
go.temporal.io/sdk v1.13.1 go.temporal.io/sdk v1.13.1
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2 // indirect golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2 // indirect
google.golang.org/genproto v0.0.0-20211104193956-4c6863e31247 google.golang.org/genproto v0.0.0-20211104193956-4c6863e31247
google.golang.org/grpc v1.44.0 google.golang.org/grpc v1.44.0

6
go.sum
View File

@ -113,6 +113,8 @@ github.com/aws/aws-sdk-go v1.34.13/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve
github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48=
github.com/aws/aws-sdk-go v1.36.12 h1:YJpKFEMbqEoo+incs5qMe61n1JH3o4O1IMkMexLzJG8= github.com/aws/aws-sdk-go v1.36.12 h1:YJpKFEMbqEoo+incs5qMe61n1JH3o4O1IMkMexLzJG8=
github.com/aws/aws-sdk-go v1.36.12/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go v1.36.12/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
github.com/aws/aws-sdk-go v1.44.129 h1:yld8Rc8OCahLtenY1mnve4w1jVeBu/rSiscGzodaDOs=
github.com/aws/aws-sdk-go v1.44.129/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@ -663,6 +665,8 @@ github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ=
github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k= github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k=
github.com/rocky-linux/srpmproc v0.3.16 h1:kxJEiQsZ0DcMhX0vY482n82XvjPiP2WifxI3NYuyLLM= github.com/rocky-linux/srpmproc v0.3.16 h1:kxJEiQsZ0DcMhX0vY482n82XvjPiP2WifxI3NYuyLLM=
github.com/rocky-linux/srpmproc v0.3.16/go.mod h1:vWZzxPTfxh4pmfr5Mw20FyrqyKsbGHzDwOlN+W5EMpw= github.com/rocky-linux/srpmproc v0.3.16/go.mod h1:vWZzxPTfxh4pmfr5Mw20FyrqyKsbGHzDwOlN+W5EMpw=
github.com/rocky-linux/srpmproc v0.4.1 h1:qcq7bGLplKbu+dSKQ9VBwcTao3OqPNb6rdKz58MCFLA=
github.com/rocky-linux/srpmproc v0.4.1/go.mod h1:x8Z2wqhV2JqRnYMhYz3thOQkfsSWjJkyX8DVGDPOb48=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
@ -981,6 +985,8 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

View File

@ -2,8 +2,14 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library( go_library(
name = "modulemd", name = "modulemd",
srcs = ["modulemd.go"], srcs = [
"modulemd.go",
"v3.go",
],
importpath = "peridot.resf.org/modulemd", importpath = "peridot.resf.org/modulemd",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = ["//vendor/gopkg.in/yaml.v3:yaml_v3"], deps = [
"//vendor/github.com/go-git/go-billy/v5:go-billy",
"//vendor/gopkg.in/yaml.v3:yaml_v3",
],
) )

View File

@ -32,6 +32,7 @@ package modulemd
import ( import (
"fmt" "fmt"
"github.com/go-git/go-billy/v5"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -152,6 +153,11 @@ type ModuleMd struct {
Data *Data `yaml:"data,omitempty"` Data *Data `yaml:"data,omitempty"`
} }
type DetectVersionDocument struct {
Document string `yaml:"document,omitempty"`
Version int `yaml:"version,omitempty"`
}
type DefaultsData struct { type DefaultsData struct {
Module string `yaml:"module,omitempty"` Module string `yaml:"module,omitempty"`
Stream string `yaml:"stream,omitempty"` Stream string `yaml:"stream,omitempty"`
@ -165,11 +171,71 @@ type Defaults struct {
} }
func Parse(input []byte) (*ModuleMd, error) { func Parse(input []byte) (*ModuleMd, error) {
var ret ModuleMd var detect DetectVersionDocument
err := yaml.Unmarshal(input, &ret) err := yaml.Unmarshal(input, &detect)
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing modulemd: %s", err) return nil, fmt.Errorf("error detecting document version: %s", err)
}
var ret ModuleMd
if detect.Version == 2 {
err = yaml.Unmarshal(input, &ret)
if err != nil {
return nil, fmt.Errorf("error parsing modulemd: %s", err)
}
} else if detect.Version == 3 {
var v3 V3
err = yaml.Unmarshal(input, &v3)
if err != nil {
return nil, fmt.Errorf("error parsing modulemd: %s", err)
}
ret = ModuleMd{
Document: v3.Document,
Version: v3.Version,
Data: &Data{
Name: v3.Data.Name,
Stream: v3.Data.Stream,
Summary: v3.Data.Summary,
Description: v3.Data.Description,
License: &License{
Module: v3.Data.License,
},
Xmd: v3.Data.Xmd,
References: v3.Data.References,
Profiles: v3.Data.Profiles,
Profile: v3.Data.Profile,
API: v3.Data.API,
Filter: v3.Data.Filter,
BuildOpts: &BuildOpts{
Rpms: v3.Data.Configurations[0].BuildOpts.Rpms,
Arches: v3.Data.Configurations[0].BuildOpts.Arches,
},
Components: v3.Data.Components,
},
}
} }
return &ret, nil return &ret, nil
} }
func (m *ModuleMd) Marshal(fs billy.Filesystem, path string) error {
bts, err := yaml.Marshal(m)
if err != nil {
return err
}
_ = fs.Remove(path)
f, err := fs.Create(path)
if err != nil {
return err
}
_, err = f.Write(bts)
if err != nil {
return err
}
_ = f.Close()
return nil
}

62
modulemd/v3.go Normal file
View File

@ -0,0 +1,62 @@
// Copyright (c) All respective contributors to the Peridot Project. All rights reserved.
// Copyright (c) 2021-2022 Rocky Enterprise Software Foundation, Inc. All rights reserved.
// Copyright (c) 2021-2022 Ctrl IQ, Inc. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors
// may be used to endorse or promote products derived from this software without
// specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
package modulemd
type V3 struct {
Document string `yaml:"document,omitempty"`
Version int `yaml:"version,omitempty"`
Data *V3Data `yaml:"data,omitempty"`
}
type Configurations struct {
Context string `yaml:"context,omitempty"`
Platform string `yaml:"platform,omitempty"`
BuildRequires map[string][]string `yaml:"buildrequires,omitempty"`
Requires map[string][]string `yaml:"requires,omitempty"`
BuildOpts *BuildOpts `yaml:"buildopts,omitempty"`
}
type V3Data struct {
Name string `yaml:"name,omitempty"`
Stream string `yaml:"stream,omitempty"`
Summary string `yaml:"summary,omitempty"`
Description string `yaml:"description,omitempty"`
License []string `yaml:"license,omitempty"`
Xmd map[string]map[string]string `yaml:"xmd,omitempty"`
Configurations []*Configurations `yaml:"configurations,omitempty"`
References *References `yaml:"references,omitempty"`
Profiles map[string]*Profile `yaml:"profiles,omitempty"`
Profile map[string]*Profile `yaml:"profile,omitempty"`
API *API `yaml:"api,omitempty"`
Filter *API `yaml:"filter,omitempty"`
Demodularized *API `yaml:"demodularized,omitempty"`
Components *Components `yaml:"components,omitempty"`
}

View File

@ -19,7 +19,6 @@ go_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//apollo/rpmutils", "//apollo/rpmutils",
"//modulemd",
"//peridot/composetools", "//peridot/composetools",
"//peridot/db", "//peridot/db",
"//peridot/db/models", "//peridot/db/models",
@ -47,6 +46,7 @@ go_library(
"//vendor/github.com/go-git/go-git/v5/storage/memory", "//vendor/github.com/go-git/go-git/v5/storage/memory",
"//vendor/github.com/gobwas/glob", "//vendor/github.com/gobwas/glob",
"//vendor/github.com/google/uuid", "//vendor/github.com/google/uuid",
"//vendor/github.com/rocky-linux/srpmproc/modulemd",
"//vendor/github.com/rocky-linux/srpmproc/pb", "//vendor/github.com/rocky-linux/srpmproc/pb",
"//vendor/github.com/rocky-linux/srpmproc/pkg/data", "//vendor/github.com/rocky-linux/srpmproc/pkg/data",
"//vendor/github.com/rocky-linux/srpmproc/pkg/srpmproc", "//vendor/github.com/rocky-linux/srpmproc/pkg/srpmproc",

View File

@ -455,6 +455,10 @@ func (c *Controller) mockConfig(project *models.Project, packageVersion *models.
} }
buildMacros := c.buildMacros(project, packageVersion) buildMacros := c.buildMacros(project, packageVersion)
if extra != nil && extra.ForceDist != "" {
buildMacros["%dist"] = "." + extra.ForceDist
}
mockConfig := ` mockConfig := `
config_opts['root'] = '{additionalVendor}-{majorVersion}-{hostArch}' config_opts['root'] = '{additionalVendor}-{majorVersion}-{hostArch}'
config_opts['target_arch'] = '{arch}' config_opts['target_arch'] = '{arch}'

View File

@ -40,6 +40,7 @@ import (
"github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/storage/memory" "github.com/go-git/go-git/v5/storage/memory"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rocky-linux/srpmproc/modulemd"
"go.temporal.io/sdk/workflow" "go.temporal.io/sdk/workflow"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -48,7 +49,6 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb" "google.golang.org/protobuf/types/known/wrapperspb"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"io/ioutil" "io/ioutil"
"peridot.resf.org/modulemd"
"peridot.resf.org/peridot/composetools" "peridot.resf.org/peridot/composetools"
"peridot.resf.org/peridot/db/models" "peridot.resf.org/peridot/db/models"
peridotpb "peridot.resf.org/peridot/pb" peridotpb "peridot.resf.org/peridot/pb"
@ -223,13 +223,18 @@ func (c *Controller) BuildModuleWorkflow(ctx workflow.Context, req *peridotpb.Su
return nil, err return nil, err
} }
branchIndex := map[string]bool{}
var streamRevisions models.ImportRevisions var streamRevisions models.ImportRevisions
for _, revision := range importRevisions { for _, revision := range importRevisions {
if revision.Modular { if revision.Modular {
if len(req.Branches) > 0 && !utils.StrContains(revision.ScmBranchName, req.Branches) { if len(req.Branches) > 0 && !utils.StrContains(revision.ScmBranchName, req.Branches) {
continue continue
} }
if branchIndex[revision.ScmBranchName] {
continue
}
streamRevisions = append(streamRevisions, revision) streamRevisions = append(streamRevisions, revision)
branchIndex[revision.ScmBranchName] = true
} }
} }
@ -247,8 +252,14 @@ func (c *Controller) BuildModuleWorkflow(ctx workflow.Context, req *peridotpb.Su
repoUrl := fmt.Sprintf("%s/modules/%s", upstreamPrefix, gitlabify(pkg.Name)) repoUrl := fmt.Sprintf("%s/modules/%s", upstreamPrefix, gitlabify(pkg.Name))
authenticator, err := c.getAuthenticator(req.ProjectId)
if err != nil {
setInternalError(errorDetails, err)
return nil, err
}
r, err := git.Clone(storer, worktree, &git.CloneOptions{ r, err := git.Clone(storer, worktree, &git.CloneOptions{
URL: repoUrl, URL: repoUrl,
Auth: authenticator,
}) })
if err != nil { if err != nil {
newErr := fmt.Errorf("failed to clone module repo: %s", err) newErr := fmt.Errorf("failed to clone module repo: %s", err)
@ -291,13 +302,61 @@ func (c *Controller) BuildModuleWorkflow(ctx workflow.Context, req *peridotpb.Su
} }
// Parse yaml content to module metadata // Parse yaml content to module metadata
moduleMd, err := modulemd.Parse(yamlContent) moduleMdNotBackwardsCompatible, err := modulemd.Parse(yamlContent)
if err != nil { if err != nil {
newErr := fmt.Errorf("could not parse yaml file from modules repo in branch %s: %v", revision.ScmBranchName, err) newErr := fmt.Errorf("could not parse yaml file from modules repo in branch %s: %v", revision.ScmBranchName, err)
setActivityError(errorDetails, newErr) setActivityError(errorDetails, newErr)
return nil, newErr return nil, newErr
} }
var moduleMd *modulemd.ModuleMd
if moduleMdNotBackwardsCompatible.V2 != nil {
moduleMd = moduleMdNotBackwardsCompatible.V2
} else if moduleMdNotBackwardsCompatible.V3 != nil {
v3 := moduleMdNotBackwardsCompatible.V3
moduleMd = &modulemd.ModuleMd{
Document: "modulemd",
Version: 2,
Data: &modulemd.Data{
Name: v3.Data.Name,
Stream: v3.Data.Stream,
Summary: v3.Data.Summary,
Description: v3.Data.Description,
ServiceLevels: nil,
License: &modulemd.License{
Module: v3.Data.License,
},
Xmd: v3.Data.Xmd,
References: v3.Data.References,
Profiles: v3.Data.Profiles,
Profile: v3.Data.Profile,
API: v3.Data.API,
Filter: v3.Data.Filter,
BuildOpts: nil,
Components: v3.Data.Components,
Artifacts: nil,
},
}
if len(v3.Data.Configurations) > 0 {
cfg := v3.Data.Configurations[0]
if cfg.BuildOpts != nil {
moduleMd.Data.BuildOpts = &modulemd.BuildOpts{
Rpms: cfg.BuildOpts.Rpms,
Arches: cfg.BuildOpts.Arches,
}
moduleMd.Data.Dependencies = []*modulemd.Dependencies{
{
BuildRequires: cfg.BuildRequires,
Requires: cfg.Requires,
},
}
}
}
}
if moduleMd.Data.Name == "" {
moduleMd.Data.Name = pkg.Name
}
// Invalid modulemd in repo // Invalid modulemd in repo
if moduleMd.Data == nil || moduleMd.Data.Components == nil { if moduleMd.Data == nil || moduleMd.Data.Components == nil {
setActivityError(errorDetails, ErrInvalidModule) setActivityError(errorDetails, ErrInvalidModule)
@ -526,6 +585,7 @@ func (c *Controller) BuildModuleStreamWorkflow(ctx workflow.Context, req *perido
ExtraYumrepofsRepos: extraRepos, ExtraYumrepofsRepos: extraRepos,
BuildBatchId: streamBuildOptions.BuildBatchId, BuildBatchId: streamBuildOptions.BuildBatchId,
Modules: buildRequiresModules, Modules: buildRequiresModules,
ForceDist: streamBuildOptions.Dist,
} }
task, err := c.db.CreateTask(nil, "noarch", peridotpb.TaskType_TASK_TYPE_BUILD, &req.ProjectId, &parentTaskId) task, err := c.db.CreateTask(nil, "noarch", peridotpb.TaskType_TASK_TYPE_BUILD, &req.ProjectId, &parentTaskId)

View File

@ -179,26 +179,8 @@ func (c *Controller) RpmLookasideBatchImportWorkflow(ctx workflow.Context, req *
defer cleanupWorker() defer cleanupWorker()
taskID := task.ID.String() taskID := task.ID.String()
var importResults []*RpmImportActivityTaskStage1 var stage1 *RpmImportActivityTaskStage1
var taskIDs []string
taskIDBuildMap := map[string]*RpmImportActivityTaskStage1{}
for _, blob := range req.LookasideBlobs { for _, blob := range req.LookasideBlobs {
var archTask models.Task
archTaskEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} {
newTask, err := c.db.CreateTask(nil, "noarch", peridotpb.TaskType_TASK_TYPE_RPM_IMPORT, &req.ProjectId, &taskID)
if err != nil {
return &models.Task{}
}
_ = c.db.SetTaskStatus(newTask.ID.String(), peridotpb.TaskStatus_TASK_STATUS_RUNNING)
return newTask
})
err := archTaskEffect.Get(&archTask)
if err != nil || !archTask.ProjectId.Valid {
return nil, fmt.Errorf("failed to create rpm task: %s", err)
}
taskIDs = append(taskIDs, archTask.ID.String())
var importRes RpmImportActivityTaskStage1 var importRes RpmImportActivityTaskStage1
importCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ importCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
StartToCloseTimeout: time.Hour, StartToCloseTimeout: time.Hour,
@ -213,42 +195,40 @@ func (c *Controller) RpmLookasideBatchImportWorkflow(ctx workflow.Context, req *
Rpms: blob, Rpms: blob,
ForceOverride: req.ForceOverride, ForceOverride: req.ForceOverride,
} }
err = workflow.ExecuteActivity(importCtx, c.RpmImportActivity, blobReq, archTask.ID.String(), true).Get(ctx, &importRes) err = workflow.ExecuteActivity(importCtx, c.RpmImportActivity, blobReq, task.ID.String(), true, stage1).Get(ctx, &importRes)
if err != nil { if err != nil {
setActivityError(errorDetails, err) setActivityError(errorDetails, err)
return nil, err return nil, err
} }
importResults = append(importResults, &importRes) if stage1 == nil {
taskIDBuildMap[archTask.ID.String()] = &importRes stage1 = &importRes
}
} }
var res []*RpmImportUploadWrapper var res []*RpmImportUploadWrapper
for _, importTaskID := range taskIDs { uploadArchCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
uploadArchCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ ScheduleToStartTimeout: 12 * time.Hour,
ScheduleToStartTimeout: 12 * time.Hour, StartToCloseTimeout: 24 * time.Hour,
StartToCloseTimeout: 24 * time.Hour, HeartbeatTimeout: 2 * time.Minute,
HeartbeatTimeout: 2 * time.Minute, TaskQueue: importTaskQueue,
TaskQueue: importTaskQueue, })
})
var interimRes []*UploadActivityResult var interimRes []*UploadActivityResult
err = workflow.ExecuteActivity(uploadArchCtx, c.UploadArchActivity, req.ProjectId, importTaskID).Get(ctx, &interimRes) err = workflow.ExecuteActivity(uploadArchCtx, c.UploadArchActivity, req.ProjectId, task.ID.String()).Get(ctx, &interimRes)
if err != nil { if err != nil {
setActivityError(errorDetails, err) setActivityError(errorDetails, err)
return nil, err return nil, err
} }
for _, ires := range interimRes { for _, ires := range interimRes {
res = append(res, &RpmImportUploadWrapper{ res = append(res, &RpmImportUploadWrapper{
Upload: ires, Upload: ires,
TaskID: importTaskID, TaskID: task.ID.String(),
}) })
}
} }
for _, result := range res { for _, result := range res {
stage1 := taskIDBuildMap[result.TaskID] if result.Upload.Skip {
if stage1 == nil { continue
return nil, fmt.Errorf("failed to find task %s", result.TaskID)
} }
err = c.db.AttachTaskToBuild(stage1.Build.ID.String(), result.Upload.Subtask.ID.String()) err = c.db.AttachTaskToBuild(stage1.Build.ID.String(), result.Upload.Subtask.ID.String())
if err != nil { if err != nil {
@ -256,9 +236,6 @@ func (c *Controller) RpmLookasideBatchImportWorkflow(ctx workflow.Context, req *
setInternalError(errorDetails, err) setInternalError(errorDetails, err)
return nil, err return nil, err
} }
if result.Upload.Skip {
continue
}
} }
yumrepoCtx := workflow.WithChildOptions(ctx, workflow.ChildWorkflowOptions{ yumrepoCtx := workflow.WithChildOptions(ctx, workflow.ChildWorkflowOptions{
@ -266,14 +243,11 @@ func (c *Controller) RpmLookasideBatchImportWorkflow(ctx workflow.Context, req *
}) })
updateRepoRequest := &UpdateRepoRequest{ updateRepoRequest := &UpdateRepoRequest{
ProjectID: req.ProjectId, ProjectID: req.ProjectId,
BuildIDs: []string{}, BuildIDs: []string{stage1.Build.ID.String()},
Delete: false, Delete: false,
TaskID: &taskID, TaskID: &taskID,
NoDeletePrevious: true, NoDeletePrevious: true,
} }
for _, importRes := range importResults {
updateRepoRequest.BuildIDs = append(updateRepoRequest.BuildIDs, importRes.Build.ID.String())
}
updateRepoTask := &yumrepofspb.UpdateRepoTask{} updateRepoTask := &yumrepofspb.UpdateRepoTask{}
err = workflow.ExecuteChildWorkflow(yumrepoCtx, c.RepoUpdaterWorkflow, updateRepoRequest).Get(yumrepoCtx, updateRepoTask) err = workflow.ExecuteChildWorkflow(yumrepoCtx, c.RepoUpdaterWorkflow, updateRepoRequest).Get(yumrepoCtx, updateRepoTask)
if err != nil { if err != nil {
@ -287,7 +261,7 @@ func (c *Controller) RpmLookasideBatchImportWorkflow(ctx workflow.Context, req *
return &ret, nil return &ret, nil
} }
func (c *Controller) RpmImportActivity(ctx context.Context, req *peridotpb.RpmImportRequest, taskID string, setTaskStatus bool) (*RpmImportActivityTaskStage1, error) { func (c *Controller) RpmImportActivity(ctx context.Context, req *peridotpb.RpmImportRequest, taskID string, setTaskStatus bool, stage1 *RpmImportActivityTaskStage1) (*RpmImportActivityTaskStage1, error) {
go func() { go func() {
for { for {
activity.RecordHeartbeat(ctx) activity.RecordHeartbeat(ctx)
@ -425,33 +399,38 @@ func (c *Controller) RpmImportActivity(ctx context.Context, req *peridotpb.RpmIm
return nil, status.Error(codes.Internal, "could not set task metadata") return nil, status.Error(codes.Internal, "could not set task metadata")
} }
var packageVersionId string var build *models.Build
packageVersionId, err = tx.GetPackageVersionId(pkg.ID.String(), nvrMatch[2], nvrMatch[3]) if stage1 == nil {
if err != nil { var packageVersionId string
if err == sql.ErrNoRows { packageVersionId, err = tx.GetPackageVersionId(pkg.ID.String(), nvrMatch[2], nvrMatch[3])
packageVersionId, err = tx.CreatePackageVersion(pkg.ID.String(), nvrMatch[2], nvrMatch[3]) if err != nil {
if err != nil { if err == sql.ErrNoRows {
err = status.Errorf(codes.Internal, "could not create package version: %v", err) packageVersionId, err = tx.CreatePackageVersion(pkg.ID.String(), nvrMatch[2], nvrMatch[3])
if err != nil {
err = status.Errorf(codes.Internal, "could not create package version: %v", err)
return nil, err
}
} else {
err = status.Errorf(codes.Internal, "could not get package version id: %v", err)
return nil, err return nil, err
} }
} else { }
err = status.Errorf(codes.Internal, "could not get package version id: %v", err)
// todo(mustafa): Add published check, as well as limitations for overriding existing versions
// TODO URGENT: Don't allow nondeterministic behavior regarding versions
err = tx.AttachPackageVersion(req.ProjectId, pkg.ID.String(), packageVersionId, false)
if err != nil {
err = status.Errorf(codes.Internal, "could not attach package version: %v", err)
return nil, err return nil, err
} }
}
// todo(mustafa): Add published check, as well as limitations for overriding existing versions build, err = tx.CreateBuild(pkg.ID.String(), packageVersionId, taskID, req.ProjectId)
// TODO URGENT: Don't allow nondeterministic behavior regarding versions if err != nil {
err = tx.AttachPackageVersion(req.ProjectId, pkg.ID.String(), packageVersionId, false) err = status.Errorf(codes.Internal, "could not create build")
if err != nil { return nil, err
err = status.Errorf(codes.Internal, "could not attach package version: %v", err) }
return nil, err } else {
} build = stage1.Build
build, err := tx.CreateBuild(pkg.ID.String(), packageVersionId, taskID, req.ProjectId)
if err != nil {
err = status.Errorf(codes.Internal, "could not create build")
return nil, err
} }
targetDir := filepath.Join(rpmbuild.GetCloneDirectory(), "RPMS") targetDir := filepath.Join(rpmbuild.GetCloneDirectory(), "RPMS")

View File

@ -195,6 +195,13 @@ func kindCatalogSync(tx peridotdb.Access, req *peridotpb.SyncCatalogRequest, cat
// perl.aarch64 -> perl // perl.aarch64 -> perl
nvrIndex := map[string]string{} nvrIndex := map[string]string{}
for _, catalog := range catalogs { for _, catalog := range catalogs {
if catalog.ModuleConfiguration != nil {
if ret.ModuleConfiguration != nil {
return nil, fmt.Errorf("multiple module configurations found")
}
ret.ModuleConfiguration = catalog.ModuleConfiguration
}
for _, pkg := range catalog.Package { for _, pkg := range catalog.Package {
for _, repo := range pkg.Repository { for _, repo := range pkg.Repository {
if repoIndex[repo.Name] == nil { if repoIndex[repo.Name] == nil {
@ -222,6 +229,22 @@ func kindCatalogSync(tx peridotdb.Access, req *peridotpb.SyncCatalogRequest, cat
Type: pkg.Type, Type: pkg.Type,
}) })
} }
for _, moduleStream := range repo.ModuleStream {
modulePkg := fmt.Sprintf("module:%s:%s", pkg.Name, moduleStream)
alreadyExists := false
for _, p := range repoIndex[repo.Name].Packages {
if p.Name == modulePkg {
alreadyExists = true
break
}
}
if !alreadyExists {
repoIndex[repo.Name].Packages = append(repoIndex[repo.Name].Packages, RepoSyncPackage{
Name: modulePkg,
Type: pkg.Type,
})
}
}
for _, inf := range repo.IncludeFilter { for _, inf := range repo.IncludeFilter {
nvrIndex[inf] = pkg.Name nvrIndex[inf] = pkg.Name
if repoIndex[repo.Name].IncludeFilter[pkg.Name] == nil { if repoIndex[repo.Name].IncludeFilter[pkg.Name] == nil {
@ -321,6 +344,16 @@ func kindCatalogSync(tx peridotdb.Access, req *peridotpb.SyncCatalogRequest, cat
for _, repo := range repoIndex { for _, repo := range repoIndex {
for _, pkg := range repo.Packages { for _, pkg := range repo.Packages {
// Skip if it starts with module: as it's a module stream
if strings.HasPrefix(pkg.Name, "module:") {
continue
}
// Always refresh type, expensive but necessary
if err := tx.SetPackageType(req.ProjectId.Value, pkg.Name, pkg.Type); err != nil {
return nil, fmt.Errorf("failed to update package type: %w", err)
}
// Skip if already in project // Skip if already in project
if packageExistsIndex[pkg.Name] { if packageExistsIndex[pkg.Name] {
continue continue
@ -857,6 +890,14 @@ func (c *Controller) SyncCatalogActivity(req *peridotpb.SyncCatalogRequest) (*pe
} }
ret.CatalogSync = resKindCatalogSync ret.CatalogSync = resKindCatalogSync
// Set module configuration if it exists
if resKindCatalogSync.ModuleConfiguration != nil {
err := tx.CreateProjectModuleConfiguration(req.ProjectId.Value, resKindCatalogSync.ModuleConfiguration)
if err != nil {
return nil, fmt.Errorf("failed to create project module configuration: %w", err)
}
}
// Check if we have comps // Check if we have comps
err = checkApplyComps(w, tx, req.ProjectId.Value) err = checkApplyComps(w, tx, req.ProjectId.Value)
if err != nil { if err != nil {
@ -878,6 +919,10 @@ func (c *Controller) SyncCatalogActivity(req *peridotpb.SyncCatalogRequest) (*pe
var buildIDs []string var buildIDs []string
var newBuildPackages []string var newBuildPackages []string
for _, newPackage := range ret.CatalogSync.NewPackages { for _, newPackage := range ret.CatalogSync.NewPackages {
// Skip module streams
if strings.Contains(newPackage, "module:") {
continue
}
if utils.StrContains(newPackage, newBuildPackages) { if utils.StrContains(newPackage, newBuildPackages) {
continue continue
} }
@ -897,6 +942,10 @@ func (c *Controller) SyncCatalogActivity(req *peridotpb.SyncCatalogRequest) (*pe
newBuildPackages = append(newBuildPackages, newPackage) newBuildPackages = append(newBuildPackages, newPackage)
} }
for _, newPackage := range ret.CatalogSync.ModifiedPackages { for _, newPackage := range ret.CatalogSync.ModifiedPackages {
// Skip module streams
if strings.Contains(newPackage, "module:") {
continue
}
if utils.StrContains(newPackage, newBuildPackages) { if utils.StrContains(newPackage, newBuildPackages) {
continue continue
} }

View File

@ -43,6 +43,7 @@ import (
"fmt" "fmt"
"github.com/gobwas/glob" "github.com/gobwas/glob"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rocky-linux/srpmproc/modulemd"
"github.com/spf13/viper" "github.com/spf13/viper"
"go.temporal.io/sdk/activity" "go.temporal.io/sdk/activity"
"go.temporal.io/sdk/temporal" "go.temporal.io/sdk/temporal"
@ -56,7 +57,6 @@ import (
"io/ioutil" "io/ioutil"
"path/filepath" "path/filepath"
"peridot.resf.org/apollo/rpmutils" "peridot.resf.org/apollo/rpmutils"
"peridot.resf.org/modulemd"
"peridot.resf.org/peridot/composetools" "peridot.resf.org/peridot/composetools"
peridotdb "peridot.resf.org/peridot/db" peridotdb "peridot.resf.org/peridot/db"
"peridot.resf.org/peridot/db/models" "peridot.resf.org/peridot/db/models"
@ -539,9 +539,9 @@ func (c *Controller) UpdateRepoActivity(ctx context.Context, req *UpdateRepoRequ
} }
lockedItem, err = lock.AcquireLock( lockedItem, err = lock.AcquireLock(
req.ProjectID, req.ProjectID,
dynamolock.ReplaceData(),
) )
if err != nil { if err != nil {
c.log.Errorf("failed to acquire lock: %v", err)
continue continue
} }
break break
@ -978,7 +978,7 @@ func (c *Controller) makeRepoChanges(tx peridotdb.Access, req *UpdateRepoRequest
var currentActiveArtifacts models.TaskArtifacts var currentActiveArtifacts models.TaskArtifacts
// Get currently active artifacts // Get currently active artifacts
latestBuilds, err := c.db.GetLatestBuildIdsByPackageName(build.PackageName, project.ID.String()) latestBuilds, err := c.db.GetLatestBuildsByPackageNameAndPackageVersionID(build.PackageName, build.PackageVersionId, project.ID.String())
if err != nil { if err != nil {
setInternalError(errorDetails, err) setInternalError(errorDetails, err)
return nil, fmt.Errorf("failed to get latest build ids: %v", err) return nil, fmt.Errorf("failed to get latest build ids: %v", err)
@ -1435,10 +1435,11 @@ func (c *Controller) makeRepoChanges(tx peridotdb.Access, req *UpdateRepoRequest
if moduleStream != nil { if moduleStream != nil {
streamDocument := moduleStream.ModuleStreamDocuments[arch] streamDocument := moduleStream.ModuleStreamDocuments[arch]
if streamDocument != nil { if streamDocument != nil {
newEntry, err := modulemd.Parse(streamDocument.Streams[moduleStream.Stream]) newEntryNbc, err := modulemd.Parse(streamDocument.Streams[moduleStream.Stream])
if err != nil { if err != nil {
return nil, err return nil, err
} }
newEntry := newEntryNbc.V2
// If a previous entry exists, we need to overwrite that // If a previous entry exists, we need to overwrite that
var moduleIndex *int var moduleIndex *int

View File

@ -60,7 +60,7 @@ func init() {
} }
func mn(_ *cobra.Command, _ []string) { func mn(_ *cobra.Command, _ []string) {
sess, err := utils.NewAwsSession(&aws.Config{}) sess, err := utils.NewAwsSessionNoLocalStack(&aws.Config{})
if err != nil { if err != nil {
logrus.Fatal(err) logrus.Fatal(err)
} }

View File

@ -42,6 +42,7 @@ type Access interface {
ListProjects(filters *peridotpb.ProjectFilters) (models.Projects, error) ListProjects(filters *peridotpb.ProjectFilters) (models.Projects, error)
GetProjectKeys(projectId string) (*models.ProjectKey, error) GetProjectKeys(projectId string) (*models.ProjectKey, error)
GetProjectModuleConfiguration(projectId string) (*peridotpb.ModuleConfiguration, error) GetProjectModuleConfiguration(projectId string) (*peridotpb.ModuleConfiguration, error)
CreateProjectModuleConfiguration(projectId string, config *peridotpb.ModuleConfiguration) error
CreateProject(project *peridotpb.Project) (*models.Project, error) CreateProject(project *peridotpb.Project) (*models.Project, error)
UpdateProject(id string, project *peridotpb.Project) (*models.Project, error) UpdateProject(id string, project *peridotpb.Project) (*models.Project, error)
SetProjectKeys(projectId string, username string, password string) error SetProjectKeys(projectId string, username string, password string) error
@ -65,6 +66,7 @@ type Access interface {
NVRAExists(nvra string) (bool, error) NVRAExists(nvra string) (bool, error)
GetBuildByPackageNameAndVersionAndRelease(name string, version string, release string, projectId string) (*models.Build, error) GetBuildByPackageNameAndVersionAndRelease(name string, version string, release string, projectId string) (*models.Build, error)
GetLatestBuildIdsByPackageName(name string, projectId string) ([]string, error) GetLatestBuildIdsByPackageName(name string, projectId string) ([]string, error)
GetLatestBuildsByPackageNameAndPackageVersionID(name string, packageVersionId string, projectId string) ([]string, error)
GetActiveBuildIdsByTaskArtifactGlob(taskArtifactGlob string, projectId string) ([]string, error) GetActiveBuildIdsByTaskArtifactGlob(taskArtifactGlob string, projectId string) ([]string, error)
GetAllBuildIdsByPackageName(name string, projectId string) ([]string, error) GetAllBuildIdsByPackageName(name string, projectId string) ([]string, error)
@ -98,6 +100,7 @@ type Access interface {
SetExtraOptionsForPackage(projectId string, packageName string, withFlags pq.StringArray, withoutFlags pq.StringArray) error SetExtraOptionsForPackage(projectId string, packageName string, withFlags pq.StringArray, withoutFlags pq.StringArray) error
GetExtraOptionsForPackage(projectId string, packageName string) (*models.ExtraOptions, error) GetExtraOptionsForPackage(projectId string, packageName string) (*models.ExtraOptions, error)
SetGroupInstallOptionsForPackage(projectId string, packageName string, dependsOn pq.StringArray) error SetGroupInstallOptionsForPackage(projectId string, packageName string, dependsOn pq.StringArray) error
SetPackageType(projectId string, packageName string, packageType peridotpb.PackageType) error
CreateTask(user *utils.ContextUser, arch string, taskType peridotpb.TaskType, projectId *string, parentTaskId *string) (*models.Task, error) CreateTask(user *utils.ContextUser, arch string, taskType peridotpb.TaskType, projectId *string, parentTaskId *string) (*models.Task, error)
SetTaskStatus(id string, status peridotpb.TaskStatus) error SetTaskStatus(id string, status peridotpb.TaskStatus) error

View File

@ -428,6 +428,37 @@ func (a *Access) GetLatestBuildIdsByPackageName(name string, projectId string) (
return ret, nil return ret, nil
} }
func (a *Access) GetLatestBuildsByPackageNameAndPackageVersionID(name string, packageVersionId string, projectId string) ([]string, error) {
var ret []string
err := a.query.Select(
&ret,
`
select
b.id
from builds b
inner join tasks t on t.id = b.task_id
inner join packages p on p.id = b.package_id
inner join project_package_versions ppv on ppv.package_version_id = b.package_version_id
where
b.project_id = $1
and p.name = $2
and ppv.active_in_repo = true
and ppv.project_id = b.project_id
and b.package_version_id = $3
and t.status = 3
order by b.created_at asc
`,
projectId,
name,
packageVersionId,
)
if err != nil {
return nil, err
}
return ret, nil
}
func (a *Access) GetActiveBuildIdsByTaskArtifactGlob(taskArtifactGlob string, projectId string) ([]string, error) { func (a *Access) GetActiveBuildIdsByTaskArtifactGlob(taskArtifactGlob string, projectId string) ([]string, error) {
var ret []string var ret []string
err := a.query.Select( err := a.query.Select(

View File

@ -362,3 +362,16 @@ func (a *Access) SetGroupInstallOptionsForPackage(projectId string, packageName
) )
return err return err
} }
func (a *Access) SetPackageType(projectId string, packageName string, packageType peridotpb.PackageType) error {
_, err := a.query.Exec(
`
update project_packages set package_type_override = $3
where project_id = $1 and package_id = (select id from packages where name = $2)
`,
projectId,
packageName,
packageType,
)
return err
}

View File

@ -150,6 +150,34 @@ func (a *Access) GetProjectModuleConfiguration(projectId string) (*peridotpb.Mod
return pb, nil return pb, nil
} }
func (a *Access) CreateProjectModuleConfiguration(projectId string, config *peridotpb.ModuleConfiguration) error {
anyPb, err := anypb.New(config)
if err != nil {
return fmt.Errorf("failed to marshal module configuration: %v", err)
}
protoJson, err := protojson.Marshal(anyPb)
if err != nil {
return fmt.Errorf("failed to marshal module configuration (protojson): %v", err)
}
_, err = a.query.Exec(
`
insert into project_module_configuration (project_id, proto, active)
values ($1, $2, true)
on conflict (project_id) do update
set proto = $2, active = true
`,
projectId,
protoJson,
)
if err != nil {
return err
}
return nil
}
func (a *Access) CreateProject(project *peridotpb.Project) (*models.Project, error) { func (a *Access) CreateProject(project *peridotpb.Project) (*models.Project, error) {
if err := project.ValidateAll(); err != nil { if err := project.ValidateAll(); err != nil {
return nil, err return nil, err

View File

@ -316,23 +316,11 @@ func (s *Server) SubmitBuild(ctx context.Context, req *peridotpb.SubmitBuildRequ
return nil, errors.New("could not find upstream branch") return nil, errors.New("could not find upstream branch")
} }
build, err := tx.CreateBuild(pkg.ID.String(), importRevision.PackageVersionId, task.ID.String(), req.ProjectId)
if err != nil {
s.log.Errorf("could not create build: %v", err)
return nil, status.Error(codes.InvalidArgument, "could not create build")
}
taskProto, err := task.ToProto(true) taskProto, err := task.ToProto(true)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "could not marshal task: %v", err) return nil, status.Errorf(codes.Internal, "could not marshal task: %v", err)
} }
rollback = false
err = beginTx.Commit()
if err != nil {
return nil, status.Error(codes.Internal, "could not save, try again")
}
// Check if all branches are modular (that means it's only a module component/module) // Check if all branches are modular (that means it's only a module component/module)
allStream := true allStream := true
for _, revision := range revisions { for _, revision := range revisions {
@ -346,6 +334,12 @@ func (s *Server) SubmitBuild(ctx context.Context, req *peridotpb.SubmitBuildRequ
} }
if (packageType == peridotpb.PackageType_PACKAGE_TYPE_MODULE_FORK || packageType == peridotpb.PackageType_PACKAGE_TYPE_NORMAL_FORK_MODULE || packageType == peridotpb.PackageType_PACKAGE_TYPE_MODULE_FORK_MODULE_COMPONENT) && req.ModuleVariant { if (packageType == peridotpb.PackageType_PACKAGE_TYPE_MODULE_FORK || packageType == peridotpb.PackageType_PACKAGE_TYPE_NORMAL_FORK_MODULE || packageType == peridotpb.PackageType_PACKAGE_TYPE_MODULE_FORK_MODULE_COMPONENT) && req.ModuleVariant {
rollback = false
err = beginTx.Commit()
if err != nil {
return nil, status.Error(codes.Internal, "could not save, try again")
}
_, err = s.temporal.ExecuteWorkflow( _, err = s.temporal.ExecuteWorkflow(
context.Background(), context.Background(),
client.StartWorkflowOptions{ client.StartWorkflowOptions{
@ -356,9 +350,7 @@ func (s *Server) SubmitBuild(ctx context.Context, req *peridotpb.SubmitBuildRequ
s.temporalWorker.WorkflowController.BuildModuleWorkflow, s.temporalWorker.WorkflowController.BuildModuleWorkflow,
req, req,
task, task,
&peridotpb.ExtraBuildOptions{ &peridotpb.ExtraBuildOptions{},
ReusableBuildId: build.ID.String(),
},
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -366,6 +358,18 @@ func (s *Server) SubmitBuild(ctx context.Context, req *peridotpb.SubmitBuildRequ
} }
if packageType != peridotpb.PackageType_PACKAGE_TYPE_MODULE_FORK && packageType != peridotpb.PackageType_PACKAGE_TYPE_MODULE_FORK_MODULE_COMPONENT && len(req.Branches) == 0 && !allStream && !req.ModuleVariant { if packageType != peridotpb.PackageType_PACKAGE_TYPE_MODULE_FORK && packageType != peridotpb.PackageType_PACKAGE_TYPE_MODULE_FORK_MODULE_COMPONENT && len(req.Branches) == 0 && !allStream && !req.ModuleVariant {
build, err := tx.CreateBuild(pkg.ID.String(), importRevision.PackageVersionId, task.ID.String(), req.ProjectId)
if err != nil {
s.log.Errorf("could not create build: %v", err)
return nil, status.Error(codes.InvalidArgument, "could not create build")
}
rollback = false
err = beginTx.Commit()
if err != nil {
return nil, status.Error(codes.Internal, "could not save, try again")
}
_, err = s.temporal.ExecuteWorkflow( _, err = s.temporal.ExecuteWorkflow(
context.Background(), context.Background(),
client.StartWorkflowOptions{ client.StartWorkflowOptions{

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) All respective contributors to the Peridot Project. All rights reserved.
* Copyright (c) 2021-2022 Rocky Enterprise Software Foundation, Inc. All rights reserved.
* Copyright (c) 2021-2022 Ctrl IQ, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its contributors
* may be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
alter table project_module_configuration drop constraint project_id_uniq;

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) All respective contributors to the Peridot Project. All rights reserved.
* Copyright (c) 2021-2022 Rocky Enterprise Software Foundation, Inc. All rights reserved.
* Copyright (c) 2021-2022 Ctrl IQ, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its contributors
* may be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
alter table project_module_configuration add constraint project_id_uniq unique (project_id);

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) All respective contributors to the Peridot Project. All rights reserved.
* Copyright (c) 2021-2022 Rocky Enterprise Software Foundation, Inc. All rights reserved.
* Copyright (c) 2021-2022 Ctrl IQ, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its contributors
* may be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
alter table builds add constraint builds_task_id_package_id unique (task_id, package_id);

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) All respective contributors to the Peridot Project. All rights reserved.
* Copyright (c) 2021-2022 Rocky Enterprise Software Foundation, Inc. All rights reserved.
* Copyright (c) 2021-2022 Ctrl IQ, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its contributors
* may be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
alter table builds drop constraint builds_task_id_package_id;

View File

@ -328,6 +328,9 @@ message ExtraBuildOptions {
// Whether to enable networking in rpmbuild // Whether to enable networking in rpmbuild
bool enable_networking = 8; bool enable_networking = 8;
// Force a specific dist
string force_dist = 9;
} }
message RpmImportRequest { message RpmImportRequest {

View File

@ -8,6 +8,7 @@ import "validate/validate.proto";
import "google/api/annotations.proto"; import "google/api/annotations.proto";
import "peridot/proto/v1/task.proto"; import "peridot/proto/v1/task.proto";
import "peridot/proto/v1/package.proto"; import "peridot/proto/v1/package.proto";
import "peridot/proto/v1/module.proto";
option go_package = "peridot.resf.org/peridot/pb;peridotpb"; option go_package = "peridot.resf.org/peridot/pb;peridotpb";
@ -15,6 +16,7 @@ message CatalogSyncRepository {
string name = 1; string name = 1;
repeated string include_filter = 2; repeated string include_filter = 2;
repeated string multilib = 3; repeated string multilib = 3;
repeated string module_stream = 4;
} }
message CatalogSyncPackage { message CatalogSyncPackage {
@ -47,6 +49,7 @@ message CatalogSync {
repeated string exclude_multilib_filter = 3; repeated string exclude_multilib_filter = 3;
repeated GlobFilter exclude_filter = 4; repeated GlobFilter exclude_filter = 4;
repeated GlobFilter include_filter = 5; repeated GlobFilter include_filter = 5;
resf.peridot.v1.ModuleConfiguration module_configuration = 6;
} }
message CatalogExtraPackageOptions { message CatalogExtraPackageOptions {
@ -82,6 +85,7 @@ message KindCatalogSync {
repeated string new_repositories = 2; repeated string new_repositories = 2;
repeated string modified_repositories = 3; repeated string modified_repositories = 3;
repeated string additional_nvr_globs = 5; repeated string additional_nvr_globs = 5;
resf.peridot.v1.ModuleConfiguration module_configuration = 6;
} }
message KindCatalogExtraOptions { message KindCatalogExtraOptions {

View File

@ -94,6 +94,7 @@ const columns: GridColDef[] = [
<Chip size="small" label="Package" variant="outlined" /> <Chip size="small" label="Package" variant="outlined" />
)} )}
{(params.row['type'] === V1PackageType.ModuleFork || {(params.row['type'] === V1PackageType.ModuleFork ||
params.row['type'] === V1PackageType.NormalForkModule ||
params.row['type'] === V1PackageType.ModuleForkModuleComponent) && ( params.row['type'] === V1PackageType.ModuleForkModuleComponent) && (
<Chip <Chip
size="small" size="small"

View File

@ -53,6 +53,7 @@ def gen_from_jsonnet(name, src, outs, tags, force_normal_tags, helm_mode, **kwar
"domain_user": "{STABLE_DOMAIN_USER}", "domain_user": "{STABLE_DOMAIN_USER}",
"registry_secret": "{STABLE_REGISTRY_SECRET}", "registry_secret": "{STABLE_REGISTRY_SECRET}",
"site": "{STABLE_SITE}", "site": "{STABLE_SITE}",
"local_domain": "{STABLE_LOCAL_DOMAIN}",
"helm_mode": "false", "helm_mode": "false",
} }
if helm_mode: if helm_mode:
@ -84,6 +85,7 @@ def gen_from_jsonnet(name, src, outs, tags, force_normal_tags, helm_mode, **kwar
"domain_user", "domain_user",
"registry_secret", "registry_secret",
"site", "site",
"local_domain",
], ],
multiple_outputs = True, multiple_outputs = True,
extra_args = ["-S"], extra_args = ["-S"],

View File

@ -41,4 +41,5 @@ STABLE_OCI_REGISTRY_DOCKER ${STABLE_OCI_REGISTRY_DOCKER:-docker.io}
STABLE_REGISTRY_SECRET ${STABLE_REGISTRY_SECRET:-none} STABLE_REGISTRY_SECRET ${STABLE_REGISTRY_SECRET:-none}
STABLE_OCI_REGISTRY_NO_NESTED_SUPPORT_IN_2022_SHAME_ON_YOU_AWS ${STABLE_OCI_REGISTRY_NO_NESTED_SUPPORT_IN_2022_SHAME_ON_YOU_AWS:-false} STABLE_OCI_REGISTRY_NO_NESTED_SUPPORT_IN_2022_SHAME_ON_YOU_AWS ${STABLE_OCI_REGISTRY_NO_NESTED_SUPPORT_IN_2022_SHAME_ON_YOU_AWS:-false}
STABLE_SITE ${STABLE_SITE:-normal} STABLE_SITE ${STABLE_SITE:-normal}
STABLE_LOCAL_DOMAIN ${STABLE_LOCAL_DOMAIN:-.pdev.resf.localhost}
EOF EOF

View File

@ -39,7 +39,7 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
) )
func NewAwsSession(cfg *aws.Config) (*session.Session, error) { func awsConfigInternal(cfg *aws.Config) {
if accessKey := viper.GetString("s3-access-key"); accessKey != "" { if accessKey := viper.GetString("s3-access-key"); accessKey != "" {
cfg.Credentials = credentials.NewStaticCredentials(accessKey, viper.GetString("s3-secret-key"), "") cfg.Credentials = credentials.NewStaticCredentials(accessKey, viper.GetString("s3-secret-key"), "")
} }
@ -63,7 +63,16 @@ func NewAwsSession(cfg *aws.Config) (*session.Session, error) {
if os.Getenv("AWS_REGION") != "" { if os.Getenv("AWS_REGION") != "" {
cfg.Region = aws.String(os.Getenv("AWS_REGION")) cfg.Region = aws.String(os.Getenv("AWS_REGION"))
} }
}
func NewAwsSessionNoLocalStack(cfg *aws.Config) (*session.Session, error) {
awsConfigInternal(cfg)
return session.NewSession(cfg)
}
func NewAwsSession(cfg *aws.Config) (*session.Session, error) {
awsConfigInternal(cfg)
if os.Getenv("LOCALSTACK_ENDPOINT") != "" && os.Getenv("RESF_ENV") == "dev" { if os.Getenv("LOCALSTACK_ENDPOINT") != "" && os.Getenv("RESF_ENV") == "dev" {
cfg.Endpoint = aws.String(os.Getenv("LOCALSTACK_ENDPOINT")) cfg.Endpoint = aws.String(os.Getenv("LOCALSTACK_ENDPOINT"))
cfg.Credentials = credentials.NewStaticCredentials("test", "test", "") cfg.Credentials = credentials.NewStaticCredentials("test", "test", "")

View File

@ -50,9 +50,19 @@ func prettify(v reflect.Value, indent int, buf *bytes.Buffer) {
for i, n := range names { for i, n := range names {
val := v.FieldByName(n) val := v.FieldByName(n)
ft, ok := v.Type().FieldByName(n)
if !ok {
panic(fmt.Sprintf("expected to find field %v on type %v, but was not found", n, v.Type()))
}
buf.WriteString(strings.Repeat(" ", indent+2)) buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ") buf.WriteString(n + ": ")
prettify(val, indent+2, buf)
if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString("<sensitive>")
} else {
prettify(val, indent+2, buf)
}
if i < len(names)-1 { if i < len(names)-1 {
buf.WriteString(",\n") buf.WriteString(",\n")

View File

@ -8,6 +8,8 @@ import (
) )
// StringValue returns the string representation of a value. // StringValue returns the string representation of a value.
//
// Deprecated: Use Prettify instead.
func StringValue(i interface{}) string { func StringValue(i interface{}) string {
var buf bytes.Buffer var buf bytes.Buffer
stringValue(reflect.ValueOf(i), 0, &buf) stringValue(reflect.ValueOf(i), 0, &buf)

View File

@ -10,12 +10,13 @@ import (
// A Config provides configuration to a service client instance. // A Config provides configuration to a service client instance.
type Config struct { type Config struct {
Config *aws.Config Config *aws.Config
Handlers request.Handlers Handlers request.Handlers
PartitionID string PartitionID string
Endpoint string Endpoint string
SigningRegion string SigningRegion string
SigningName string SigningName string
ResolvedRegion string
// States that the signing name did not come from a modeled source but // States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors // was derived based on other data. Used by service client constructors
@ -88,10 +89,6 @@ func (c *Client) NewRequest(operation *request.Operation, params interface{}, da
// AddDebugHandlers injects debug logging handlers into the service to log request // AddDebugHandlers injects debug logging handlers into the service to log request
// debug information. // debug information.
func (c *Client) AddDebugHandlers() { func (c *Client) AddDebugHandlers() {
if !c.Config.LogLevel.AtLeast(aws.LogDebug) {
return
}
c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler) c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler)
c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler) c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler)
} }

View File

@ -53,6 +53,10 @@ var LogHTTPRequestHandler = request.NamedHandler{
} }
func logRequest(r *request.Request) { func logRequest(r *request.Request) {
if !r.Config.LogLevel.AtLeast(aws.LogDebug) || r.Config.Logger == nil {
return
}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
bodySeekable := aws.IsReaderSeekable(r.Body) bodySeekable := aws.IsReaderSeekable(r.Body)
@ -90,6 +94,10 @@ var LogHTTPRequestHeaderHandler = request.NamedHandler{
} }
func logRequestHeader(r *request.Request) { func logRequestHeader(r *request.Request) {
if !r.Config.LogLevel.AtLeast(aws.LogDebug) || r.Config.Logger == nil {
return
}
b, err := httputil.DumpRequestOut(r.HTTPRequest, false) b, err := httputil.DumpRequestOut(r.HTTPRequest, false)
if err != nil { if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
@ -120,6 +128,10 @@ var LogHTTPResponseHandler = request.NamedHandler{
} }
func logResponse(r *request.Request) { func logResponse(r *request.Request) {
if !r.Config.LogLevel.AtLeast(aws.LogDebug) || r.Config.Logger == nil {
return
}
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)} lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
if r.HTTPResponse == nil { if r.HTTPResponse == nil {
@ -178,7 +190,7 @@ var LogHTTPResponseHeaderHandler = request.NamedHandler{
} }
func logResponseHeader(r *request.Request) { func logResponseHeader(r *request.Request) {
if r.Config.Logger == nil { if !r.Config.LogLevel.AtLeast(aws.LogDebug) || r.Config.Logger == nil {
return return
} }

View File

@ -2,13 +2,14 @@ package metadata
// ClientInfo wraps immutable data from the client.Client structure. // ClientInfo wraps immutable data from the client.Client structure.
type ClientInfo struct { type ClientInfo struct {
ServiceName string ServiceName string
ServiceID string ServiceID string
APIVersion string APIVersion string
PartitionID string PartitionID string
Endpoint string Endpoint string
SigningName string SigningName string
SigningRegion string SigningRegion string
JSONVersion string JSONVersion string
TargetPrefix string TargetPrefix string
ResolvedRegion string
} }

View File

@ -170,6 +170,9 @@ type Config struct {
// //
// For example S3's X-Amz-Meta prefixed header will be unmarshaled to lower case // For example S3's X-Amz-Meta prefixed header will be unmarshaled to lower case
// Metadata member's map keys. The value of the header in the map is unaffected. // Metadata member's map keys. The value of the header in the map is unaffected.
//
// The AWS SDK for Go v2, uses lower case header maps by default. The v1
// SDK provides this opt-in for this option, for backwards compatibility.
LowerCaseHeaderMaps *bool LowerCaseHeaderMaps *bool
// Set this to `true` to disable the EC2Metadata client from overriding the // Set this to `true` to disable the EC2Metadata client from overriding the
@ -208,8 +211,19 @@ type Config struct {
// svc := s3.New(sess, &aws.Config{ // svc := s3.New(sess, &aws.Config{
// UseDualStack: aws.Bool(true), // UseDualStack: aws.Bool(true),
// }) // })
//
// Deprecated: This option will continue to function for S3 and S3 Control for backwards compatibility.
// UseDualStackEndpoint should be used to enable usage of a service's dual-stack endpoint for all service clients
// moving forward. For S3 and S3 Control, when UseDualStackEndpoint is set to a non-zero value it takes higher
// precedence then this option.
UseDualStack *bool UseDualStack *bool
// Sets the resolver to resolve a dual-stack endpoint for the service.
UseDualStackEndpoint endpoints.DualStackEndpointState
// UseFIPSEndpoint specifies the resolver must resolve a FIPS endpoint.
UseFIPSEndpoint endpoints.FIPSEndpointState
// SleepDelay is an override for the func the SDK will call when sleeping // SleepDelay is an override for the func the SDK will call when sleeping
// during the lifecycle of a request. Specifically this will be used for // during the lifecycle of a request. Specifically this will be used for
// request delays. This value should only be used for testing. To adjust // request delays. This value should only be used for testing. To adjust
@ -438,13 +452,6 @@ func (c *Config) WithDisableEndpointHostPrefix(t bool) *Config {
return c return c
} }
// MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs {
mergeInConfig(c, other)
}
}
// WithSTSRegionalEndpoint will set whether or not to use regional endpoint flag // WithSTSRegionalEndpoint will set whether or not to use regional endpoint flag
// when resolving the endpoint for a service // when resolving the endpoint for a service
func (c *Config) WithSTSRegionalEndpoint(sre endpoints.STSRegionalEndpoint) *Config { func (c *Config) WithSTSRegionalEndpoint(sre endpoints.STSRegionalEndpoint) *Config {
@ -459,6 +466,27 @@ func (c *Config) WithS3UsEast1RegionalEndpoint(sre endpoints.S3UsEast1RegionalEn
return c return c
} }
// WithLowerCaseHeaderMaps sets a config LowerCaseHeaderMaps value
// returning a Config pointer for chaining.
func (c *Config) WithLowerCaseHeaderMaps(t bool) *Config {
c.LowerCaseHeaderMaps = &t
return c
}
// WithDisableRestProtocolURICleaning sets a config DisableRestProtocolURICleaning value
// returning a Config pointer for chaining.
func (c *Config) WithDisableRestProtocolURICleaning(t bool) *Config {
c.DisableRestProtocolURICleaning = &t
return c
}
// MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs {
mergeInConfig(c, other)
}
}
func mergeInConfig(dst *Config, other *Config) { func mergeInConfig(dst *Config, other *Config) {
if other == nil { if other == nil {
return return
@ -540,6 +568,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.UseDualStack = other.UseDualStack dst.UseDualStack = other.UseDualStack
} }
if other.UseDualStackEndpoint != endpoints.DualStackEndpointStateUnset {
dst.UseDualStackEndpoint = other.UseDualStackEndpoint
}
if other.EC2MetadataDisableTimeoutOverride != nil { if other.EC2MetadataDisableTimeoutOverride != nil {
dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride
} }
@ -571,6 +603,18 @@ func mergeInConfig(dst *Config, other *Config) {
if other.S3UsEast1RegionalEndpoint != endpoints.UnsetS3UsEast1Endpoint { if other.S3UsEast1RegionalEndpoint != endpoints.UnsetS3UsEast1Endpoint {
dst.S3UsEast1RegionalEndpoint = other.S3UsEast1RegionalEndpoint dst.S3UsEast1RegionalEndpoint = other.S3UsEast1RegionalEndpoint
} }
if other.LowerCaseHeaderMaps != nil {
dst.LowerCaseHeaderMaps = other.LowerCaseHeaderMaps
}
if other.UseDualStackEndpoint != endpoints.DualStackEndpointStateUnset {
dst.UseDualStackEndpoint = other.UseDualStackEndpoint
}
if other.UseFIPSEndpoint != endpoints.FIPSEndpointStateUnset {
dst.UseFIPSEndpoint = other.UseFIPSEndpoint
}
} }
// Copy will return a shallow copy of the Config object. If any additional // Copy will return a shallow copy of the Config object. If any additional

View File

@ -1,3 +1,4 @@
//go:build !go1.9
// +build !go1.9 // +build !go1.9
package aws package aws

View File

@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9 // +build go1.9
package aws package aws

View File

@ -1,3 +1,4 @@
//go:build !go1.7
// +build !go1.7 // +build !go1.7
package aws package aws

View File

@ -1,3 +1,4 @@
//go:build go1.7
// +build go1.7 // +build go1.7
package aws package aws

View File

@ -178,7 +178,7 @@ func handleSendError(r *request.Request, err error) {
var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) { var ValidateResponseHandler = request.NamedHandler{Name: "core.ValidateResponseHandler", Fn: func(r *request.Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 { if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler // this may be replaced by an UnmarshalError handler
r.Error = awserr.New("UnknownError", "unknown error", nil) r.Error = awserr.New("UnknownError", "unknown error", r.Error)
} }
}} }}

View File

@ -1,3 +1,4 @@
//go:build !go1.7
// +build !go1.7 // +build !go1.7
package credentials package credentials

View File

@ -1,3 +1,4 @@
//go:build go1.7
// +build go1.7 // +build go1.7
package credentials package credentials

View File

@ -1,3 +1,4 @@
//go:build !go1.9
// +build !go1.9 // +build !go1.9
package credentials package credentials

View File

@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9 // +build go1.9
package credentials package credentials

View File

@ -0,0 +1,22 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "ssocreds",
srcs = [
"doc.go",
"os.go",
"os_windows.go",
"provider.go",
],
importmap = "peridot.resf.org/vendor/github.com/aws/aws-sdk-go/aws/credentials/ssocreds",
importpath = "github.com/aws/aws-sdk-go/aws/credentials/ssocreds",
visibility = ["//visibility:public"],
deps = [
"//vendor/github.com/aws/aws-sdk-go/aws",
"//vendor/github.com/aws/aws-sdk-go/aws/awserr",
"//vendor/github.com/aws/aws-sdk-go/aws/client",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials",
"//vendor/github.com/aws/aws-sdk-go/service/sso",
"//vendor/github.com/aws/aws-sdk-go/service/sso/ssoiface",
],
)

View File

@ -0,0 +1,60 @@
// Package ssocreds provides a credential provider for retrieving temporary AWS credentials using an SSO access token.
//
// IMPORTANT: The provider in this package does not initiate or perform the AWS SSO login flow. The SDK provider
// expects that you have already performed the SSO login flow using AWS CLI using the "aws sso login" command, or by
// some other mechanism. The provider must find a valid non-expired access token for the AWS SSO user portal URL in
// ~/.aws/sso/cache. If a cached token is not found, it is expired, or the file is malformed an error will be returned.
//
// Loading AWS SSO credentials with the AWS shared configuration file
//
// You can use configure AWS SSO credentials from the AWS shared configuration file by
// providing the specifying the required keys in the profile:
//
// sso_account_id
// sso_region
// sso_role_name
// sso_start_url
//
// For example, the following defines a profile "devsso" and specifies the AWS SSO parameters that defines the target
// account, role, sign-on portal, and the region where the user portal is located. Note: all SSO arguments must be
// provided, or an error will be returned.
//
// [profile devsso]
// sso_start_url = https://my-sso-portal.awsapps.com/start
// sso_role_name = SSOReadOnlyRole
// sso_region = us-east-1
// sso_account_id = 123456789012
//
// Using the config module, you can load the AWS SDK shared configuration, and specify that this profile be used to
// retrieve credentials. For example:
//
// sess, err := session.NewSessionWithOptions(session.Options{
// SharedConfigState: session.SharedConfigEnable,
// Profile: "devsso",
// })
// if err != nil {
// return err
// }
//
// Programmatically loading AWS SSO credentials directly
//
// You can programmatically construct the AWS SSO Provider in your application, and provide the necessary information
// to load and retrieve temporary credentials using an access token from ~/.aws/sso/cache.
//
// svc := sso.New(sess, &aws.Config{
// Region: aws.String("us-west-2"), // Client Region must correspond to the AWS SSO user portal region
// })
//
// provider := ssocreds.NewCredentialsWithClient(svc, "123456789012", "SSOReadOnlyRole", "https://my-sso-portal.awsapps.com/start")
//
// credentials, err := provider.Get()
// if err != nil {
// return err
// }
//
// Additional Resources
//
// Configuring the AWS CLI to use AWS Single Sign-On: https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html
//
// AWS Single Sign-On User Guide: https://docs.aws.amazon.com/singlesignon/latest/userguide/what-is.html
package ssocreds

View File

@ -0,0 +1,10 @@
//go:build !windows
// +build !windows
package ssocreds
import "os"
func getHomeDirectory() string {
return os.Getenv("HOME")
}

View File

@ -0,0 +1,7 @@
package ssocreds
import "os"
func getHomeDirectory() string {
return os.Getenv("USERPROFILE")
}

View File

@ -0,0 +1,180 @@
package ssocreds
import (
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sso"
"github.com/aws/aws-sdk-go/service/sso/ssoiface"
)
// ErrCodeSSOProviderInvalidToken is the code type that is returned if loaded token has expired or is otherwise invalid.
// To refresh the SSO session run aws sso login with the corresponding profile.
const ErrCodeSSOProviderInvalidToken = "SSOProviderInvalidToken"
const invalidTokenMessage = "the SSO session has expired or is invalid"
func init() {
nowTime = time.Now
defaultCacheLocation = defaultCacheLocationImpl
}
var nowTime func() time.Time
// ProviderName is the name of the provider used to specify the source of credentials.
const ProviderName = "SSOProvider"
var defaultCacheLocation func() string
func defaultCacheLocationImpl() string {
return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache")
}
// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token.
type Provider struct {
credentials.Expiry
// The Client which is configured for the AWS Region where the AWS SSO user portal is located.
Client ssoiface.SSOAPI
// The AWS account that is assigned to the user.
AccountID string
// The role name that is assigned to the user.
RoleName string
// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
StartURL string
}
// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured
// for the AWS Region where the AWS SSO user portal is located.
func NewCredentials(configProvider client.ConfigProvider, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
return NewCredentialsWithClient(sso.New(configProvider), accountID, roleName, startURL, optFns...)
}
// NewCredentialsWithClient returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured
// for the AWS Region where the AWS SSO user portal is located.
func NewCredentialsWithClient(client ssoiface.SSOAPI, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
p := &Provider{
Client: client,
AccountID: accountID,
RoleName: roleName,
StartURL: startURL,
}
for _, fn := range optFns {
fn(p)
}
return credentials.NewCredentials(p)
}
// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}
// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
tokenFile, err := loadTokenFile(p.StartURL)
if err != nil {
return credentials.Value{}, err
}
output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{
AccessToken: &tokenFile.AccessToken,
AccountId: &p.AccountID,
RoleName: &p.RoleName,
})
if err != nil {
return credentials.Value{}, err
}
expireTime := time.Unix(0, aws.Int64Value(output.RoleCredentials.Expiration)*int64(time.Millisecond)).UTC()
p.SetExpiration(expireTime, 0)
return credentials.Value{
AccessKeyID: aws.StringValue(output.RoleCredentials.AccessKeyId),
SecretAccessKey: aws.StringValue(output.RoleCredentials.SecretAccessKey),
SessionToken: aws.StringValue(output.RoleCredentials.SessionToken),
ProviderName: ProviderName,
}, nil
}
func getCacheFileName(url string) (string, error) {
hash := sha1.New()
_, err := hash.Write([]byte(url))
if err != nil {
return "", err
}
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
}
type rfc3339 time.Time
func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
var value string
if err := json.Unmarshal(bytes, &value); err != nil {
return err
}
parse, err := time.Parse(time.RFC3339, value)
if err != nil {
return fmt.Errorf("expected RFC3339 timestamp: %v", err)
}
*r = rfc3339(parse)
return nil
}
type token struct {
AccessToken string `json:"accessToken"`
ExpiresAt rfc3339 `json:"expiresAt"`
Region string `json:"region,omitempty"`
StartURL string `json:"startUrl,omitempty"`
}
func (t token) Expired() bool {
return nowTime().Round(0).After(time.Time(t.ExpiresAt))
}
func loadTokenFile(startURL string) (t token, err error) {
key, err := getCacheFileName(startURL)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
if err := json.Unmarshal(fileBytes, &t); err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
if len(t.AccessToken) == 0 {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
}
if t.Expired() {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
}
return t, nil
}

View File

@ -95,7 +95,7 @@ import (
// StdinTokenProvider will prompt on stderr and read from stdin for a string value. // StdinTokenProvider will prompt on stderr and read from stdin for a string value.
// An error is returned if reading from stdin fails. // An error is returned if reading from stdin fails.
// //
// Use this function go read MFA tokens from stdin. The function makes no attempt // Use this function to read MFA tokens from stdin. The function makes no attempt
// to make atomic prompts from stdin across multiple gorouties. // to make atomic prompts from stdin across multiple gorouties.
// //
// Using StdinTokenProvider with multiple AssumeRoleProviders, or Credentials will // Using StdinTokenProvider with multiple AssumeRoleProviders, or Credentials will
@ -244,9 +244,11 @@ type AssumeRoleProvider struct {
MaxJitterFrac float64 MaxJitterFrac float64
} }
// NewCredentials returns a pointer to a new Credentials object wrapping the // NewCredentials returns a pointer to a new Credentials value wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the // AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation. // role will be named after a nanosecond timestamp of this operation. The
// Credentials value will attempt to refresh the credentials using the provider
// when Credentials.Get is called, if the cached credentials are expiring.
// //
// Takes a Config provider to create the STS client. The ConfigProvider is // Takes a Config provider to create the STS client. The ConfigProvider is
// satisfied by the session.Session type. // satisfied by the session.Session type.
@ -268,9 +270,11 @@ func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*As
return credentials.NewCredentials(p) return credentials.NewCredentials(p)
} }
// NewCredentialsWithClient returns a pointer to a new Credentials object wrapping the // NewCredentialsWithClient returns a pointer to a new Credentials value wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the // AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation. // role will be named after a nanosecond timestamp of this operation. The
// Credentials value will attempt to refresh the credentials using the provider
// when Credentials.Get is called, if the cached credentials are expiring.
// //
// Takes an AssumeRoler which can be satisfied by the STS client. // Takes an AssumeRoler which can be satisfied by the STS client.
// //

View File

@ -28,7 +28,7 @@ const (
// compare test values. // compare test values.
var now = time.Now var now = time.Now
// TokenFetcher shuold return WebIdentity token bytes or an error // TokenFetcher should return WebIdentity token bytes or an error
type TokenFetcher interface { type TokenFetcher interface {
FetchToken(credentials.Context) ([]byte, error) FetchToken(credentials.Context) ([]byte, error)
} }
@ -50,6 +50,8 @@ func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) {
// an OIDC token. // an OIDC token.
type WebIdentityRoleProvider struct { type WebIdentityRoleProvider struct {
credentials.Expiry credentials.Expiry
// The policy ARNs to use with the web identity assumed role.
PolicyArns []*sts.PolicyDescriptorType PolicyArns []*sts.PolicyDescriptorType
// Duration the STS credentials will be valid for. Truncated to seconds. // Duration the STS credentials will be valid for. Truncated to seconds.
@ -74,6 +76,9 @@ type WebIdentityRoleProvider struct {
// NewWebIdentityCredentials will return a new set of credentials with a given // NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path. // configuration, role arn, and token file path.
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options, and wrap with credentials.NewCredentials helper.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials { func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c) svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path) p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
@ -82,19 +87,42 @@ func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName
// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the // NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI // provided stsiface.STSAPI
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options.
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider { func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return NewWebIdentityRoleProviderWithToken(svc, roleARN, roleSessionName, FetchTokenPath(path)) return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, FetchTokenPath(path))
} }
// NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the // NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI and a TokenFetcher // provided stsiface.STSAPI and a TokenFetcher
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options.
func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider { func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{ return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, tokenFetcher)
}
// NewWebIdentityRoleProviderWithOptions will return an initialize
// WebIdentityRoleProvider with the provided stsiface.STSAPI, role ARN, and a
// TokenFetcher. Additional options can be provided as functional options.
//
// TokenFetcher is the implementation that will retrieve the JWT token from to
// assume the role with. Use the provided FetchTokenPath implementation to
// retrieve the JWT token using a file system path.
func NewWebIdentityRoleProviderWithOptions(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher, optFns ...func(*WebIdentityRoleProvider)) *WebIdentityRoleProvider {
p := WebIdentityRoleProvider{
client: svc, client: svc,
tokenFetcher: tokenFetcher, tokenFetcher: tokenFetcher,
roleARN: roleARN, roleARN: roleARN,
roleSessionName: roleSessionName, roleSessionName: roleSessionName,
} }
for _, fn := range optFns {
fn(&p)
}
return &p
} }
// Retrieve will attempt to assume a role from a token which is located at // Retrieve will attempt to assume a role from a token which is located at
@ -104,9 +132,9 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext()) return p.RetrieveWithContext(aws.BackgroundContext())
} }
// RetrieveWithContext will attempt to assume a role from a token which is located at // RetrieveWithContext will attempt to assume a role from a token which is
// 'WebIdentityTokenFilePath' specified destination and if that is empty an // located at 'WebIdentityTokenFilePath' specified destination and if that is
// error will be returned. // empty an error will be returned.
func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
b, err := p.tokenFetcher.FetchToken(ctx) b, err := p.tokenFetcher.FetchToken(ctx)
if err != nil { if err != nil {

View File

@ -34,7 +34,10 @@ func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) {
return Endpoint{}, false return Endpoint{}, false
} }
c.endpoints.Store(endpointKey, endpoint) ev := endpoint.(Endpoint)
ev.Prune()
c.endpoints.Store(endpointKey, ev)
return endpoint.(Endpoint), true return endpoint.(Endpoint), true
} }

View File

@ -60,12 +60,32 @@ func (e *Endpoint) GetValidAddress() (WeightedAddress, bool) {
continue continue
} }
we.URL = cloneURL(we.URL)
return we, true return we, true
} }
return WeightedAddress{}, false return WeightedAddress{}, false
} }
// Prune will prune the expired addresses from the endpoint by allocating a new []WeightAddress.
// This is not concurrent safe, and should be called from a single owning thread.
func (e *Endpoint) Prune() bool {
validLen := e.Len()
if validLen == len(e.Addresses) {
return false
}
wa := make([]WeightedAddress, 0, validLen)
for i := range e.Addresses {
if e.Addresses[i].HasExpired() {
continue
}
wa = append(wa, e.Addresses[i])
}
e.Addresses = wa
return true
}
// Discoverer is an interface used to discovery which endpoint hit. This // Discoverer is an interface used to discovery which endpoint hit. This
// allows for specifics about what parameters need to be used to be contained // allows for specifics about what parameters need to be used to be contained
// in the Discoverer implementor. // in the Discoverer implementor.
@ -97,3 +117,16 @@ func BuildEndpointKey(params map[string]*string) string {
return strings.Join(values, ".") return strings.Join(values, ".")
} }
func cloneURL(u *url.URL) (clone *url.URL) {
clone = &url.URL{}
*clone = *u
if u.User != nil {
user := *u.User
clone.User = &user
}
return clone
}

View File

@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9 // +build go1.9
package crr package crr

View File

@ -1,3 +1,4 @@
//go:build !go1.9
// +build !go1.9 // +build !go1.9
package crr package crr

View File

@ -13,7 +13,6 @@ package ec2metadata
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -234,7 +233,8 @@ func unmarshalError(r *request.Request) {
// Response body format is not consistent between metadata endpoints. // Response body format is not consistent between metadata endpoints.
// Grab the error message as a string and include that as the source error // Grab the error message as a string and include that as the source error
r.Error = awserr.NewRequestFailure(awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String())), r.Error = awserr.NewRequestFailure(
awserr.New("EC2MetadataError", "failed to make EC2Metadata request\n"+b.String(), nil),
r.HTTPResponse.StatusCode, r.RequestID) r.HTTPResponse.StatusCode, r.RequestID)
} }

View File

@ -31,12 +31,12 @@ func (d *DecodeModelOptions) Set(optFns ...func(*DecodeModelOptions)) {
// allow you to get a list of the partitions in the order the endpoints // allow you to get a list of the partitions in the order the endpoints
// will be resolved in. // will be resolved in.
// //
// resolver, err := endpoints.DecodeModel(reader) // resolver, err := endpoints.DecodeModel(reader)
// //
// partitions := resolver.(endpoints.EnumPartitions).Partitions() // partitions := resolver.(endpoints.EnumPartitions).Partitions()
// for _, p := range partitions { // for _, p := range partitions {
// // ... inspect partitions // // ... inspect partitions
// } // }
func DecodeModel(r io.Reader, optFns ...func(*DecodeModelOptions)) (Resolver, error) { func DecodeModel(r io.Reader, optFns ...func(*DecodeModelOptions)) (Resolver, error) {
var opts DecodeModelOptions var opts DecodeModelOptions
opts.Set(optFns...) opts.Set(optFns...)
@ -81,8 +81,6 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
// Customization // Customization
for i := 0; i < len(ps); i++ { for i := 0; i < len(ps); i++ {
p := &ps[i] p := &ps[i]
custAddEC2Metadata(p)
custAddS3DualStack(p)
custRegionalS3(p) custRegionalS3(p)
custRmIotDataService(p) custRmIotDataService(p)
custFixAppAutoscalingChina(p) custFixAppAutoscalingChina(p)
@ -92,15 +90,6 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
return ps, nil return ps, nil
} }
func custAddS3DualStack(p *partition) {
if !(p.ID == "aws" || p.ID == "aws-cn" || p.ID == "aws-us-gov") {
return
}
custAddDualstack(p, "s3")
custAddDualstack(p, "s3-control")
}
func custRegionalS3(p *partition) { func custRegionalS3(p *partition) {
if p.ID != "aws" { if p.ID != "aws" {
return return
@ -111,48 +100,28 @@ func custRegionalS3(p *partition) {
return return
} }
const awsGlobal = "aws-global"
const usEast1 = "us-east-1"
// If global endpoint already exists no customization needed. // If global endpoint already exists no customization needed.
if _, ok := service.Endpoints["aws-global"]; ok { if _, ok := service.Endpoints[endpointKey{Region: awsGlobal}]; ok {
return return
} }
service.PartitionEndpoint = "aws-global" service.PartitionEndpoint = awsGlobal
service.Endpoints["us-east-1"] = endpoint{} if _, ok := service.Endpoints[endpointKey{Region: usEast1}]; !ok {
service.Endpoints["aws-global"] = endpoint{ service.Endpoints[endpointKey{Region: usEast1}] = endpoint{}
}
service.Endpoints[endpointKey{Region: awsGlobal}] = endpoint{
Hostname: "s3.amazonaws.com", Hostname: "s3.amazonaws.com",
CredentialScope: credentialScope{ CredentialScope: credentialScope{
Region: "us-east-1", Region: usEast1,
}, },
} }
p.Services["s3"] = service p.Services["s3"] = service
} }
func custAddDualstack(p *partition, svcName string) {
s, ok := p.Services[svcName]
if !ok {
return
}
s.Defaults.HasDualStack = boxedTrue
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
p.Services[svcName] = s
}
func custAddEC2Metadata(p *partition) {
p.Services["ec2metadata"] = service{
IsRegionalized: boxedFalse,
PartitionEndpoint: "aws-global",
Endpoints: endpoints{
"aws-global": endpoint{
Hostname: "169.254.169.254/latest",
Protocols: []string{"http"},
},
},
}
}
func custRmIotDataService(p *partition) { func custRmIotDataService(p *partition) {
delete(p.Services, "data.iot") delete(p.Services, "data.iot")
} }
@ -169,12 +138,13 @@ func custFixAppAutoscalingChina(p *partition) {
} }
const expectHostname = `autoscaling.{region}.amazonaws.com` const expectHostname = `autoscaling.{region}.amazonaws.com`
if e, a := s.Defaults.Hostname, expectHostname; e != a { serviceDefault := s.Defaults[defaultKey{}]
if e, a := expectHostname, serviceDefault.Hostname; e != a {
fmt.Printf("custFixAppAutoscalingChina: ignoring customization, expected %s, got %s\n", e, a) fmt.Printf("custFixAppAutoscalingChina: ignoring customization, expected %s, got %s\n", e, a)
return return
} }
serviceDefault.Hostname = expectHostname + ".cn"
s.Defaults.Hostname = expectHostname + ".cn" s.Defaults[defaultKey{}] = serviceDefault
p.Services[serviceName] = s p.Services[serviceName] = s
} }
@ -189,18 +159,25 @@ func custFixAppAutoscalingUsGov(p *partition) {
return return
} }
if a := s.Defaults.CredentialScope.Service; a != "" { serviceDefault := s.Defaults[defaultKey{}]
if a := serviceDefault.CredentialScope.Service; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a) fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a)
return return
} }
if a := s.Defaults.Hostname; a != "" { if a := serviceDefault.Hostname; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a) fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a)
return return
} }
s.Defaults.CredentialScope.Service = "application-autoscaling" serviceDefault.CredentialScope.Service = "application-autoscaling"
s.Defaults.Hostname = "autoscaling.{region}.amazonaws.com" serviceDefault.Hostname = "autoscaling.{region}.amazonaws.com"
if s.Defaults == nil {
s.Defaults = make(endpointDefaults)
}
s.Defaults[defaultKey{}] = serviceDefault
p.Services[serviceName] = s p.Services[serviceName] = s
} }

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,7 @@
// AWS GovCloud (US) (aws-us-gov). // AWS GovCloud (US) (aws-us-gov).
// . // .
// //
// Enumerating Regions and Endpoint Metadata // # Enumerating Regions and Endpoint Metadata
// //
// Casting the Resolver returned by DefaultResolver to a EnumPartitions interface // Casting the Resolver returned by DefaultResolver to a EnumPartitions interface
// will allow you to get access to the list of underlying Partitions with the // will allow you to get access to the list of underlying Partitions with the
@ -17,22 +17,22 @@
// resolving to a single partition, or enumerate regions, services, and endpoints // resolving to a single partition, or enumerate regions, services, and endpoints
// in the partition. // in the partition.
// //
// resolver := endpoints.DefaultResolver() // resolver := endpoints.DefaultResolver()
// partitions := resolver.(endpoints.EnumPartitions).Partitions() // partitions := resolver.(endpoints.EnumPartitions).Partitions()
// //
// for _, p := range partitions { // for _, p := range partitions {
// fmt.Println("Regions for", p.ID()) // fmt.Println("Regions for", p.ID())
// for id, _ := range p.Regions() { // for id, _ := range p.Regions() {
// fmt.Println("*", id) // fmt.Println("*", id)
// } // }
// //
// fmt.Println("Services for", p.ID()) // fmt.Println("Services for", p.ID())
// for id, _ := range p.Services() { // for id, _ := range p.Services() {
// fmt.Println("*", id) // fmt.Println("*", id)
// } // }
// } // }
// //
// Using Custom Endpoints // # Using Custom Endpoints
// //
// The endpoints package also gives you the ability to use your own logic how // The endpoints package also gives you the ability to use your own logic how
// endpoints are resolved. This is a great way to define a custom endpoint // endpoints are resolved. This is a great way to define a custom endpoint
@ -47,20 +47,19 @@
// of Resolver.EndpointFor, converting it to a type that satisfies the // of Resolver.EndpointFor, converting it to a type that satisfies the
// Resolver interface. // Resolver interface.
// //
// myCustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
// if service == endpoints.S3ServiceID {
// return endpoints.ResolvedEndpoint{
// URL: "s3.custom.endpoint.com",
// SigningRegion: "custom-signing-region",
// }, nil
// }
// //
// myCustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { // return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
// if service == endpoints.S3ServiceID { // }
// return endpoints.ResolvedEndpoint{
// URL: "s3.custom.endpoint.com",
// SigningRegion: "custom-signing-region",
// }, nil
// }
// //
// return endpoints.DefaultResolver().EndpointFor(service, region, optFns...) // sess := session.Must(session.NewSession(&aws.Config{
// } // Region: aws.String("us-west-2"),
// // EndpointResolver: endpoints.ResolverFunc(myCustomResolver),
// sess := session.Must(session.NewSession(&aws.Config{ // }))
// Region: aws.String("us-west-2"),
// EndpointResolver: endpoints.ResolverFunc(myCustomResolver),
// }))
package endpoints package endpoints

View File

@ -8,6 +8,41 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
) )
// A Logger is a minimalistic interface for the SDK to log messages to.
type Logger interface {
Log(...interface{})
}
// DualStackEndpointState is a constant to describe the dual-stack endpoint resolution
// behavior.
type DualStackEndpointState uint
const (
// DualStackEndpointStateUnset is the default value behavior for dual-stack endpoint
// resolution.
DualStackEndpointStateUnset DualStackEndpointState = iota
// DualStackEndpointStateEnabled enable dual-stack endpoint resolution for endpoints.
DualStackEndpointStateEnabled
// DualStackEndpointStateDisabled disables dual-stack endpoint resolution for endpoints.
DualStackEndpointStateDisabled
)
// FIPSEndpointState is a constant to describe the FIPS endpoint resolution behavior.
type FIPSEndpointState uint
const (
// FIPSEndpointStateUnset is the default value behavior for FIPS endpoint resolution.
FIPSEndpointStateUnset FIPSEndpointState = iota
// FIPSEndpointStateEnabled enables FIPS endpoint resolution for service endpoints.
FIPSEndpointStateEnabled
// FIPSEndpointStateDisabled disables FIPS endpoint resolution for endpoints.
FIPSEndpointStateDisabled
)
// Options provide the configuration needed to direct how the // Options provide the configuration needed to direct how the
// endpoints will be resolved. // endpoints will be resolved.
type Options struct { type Options struct {
@ -21,8 +56,19 @@ type Options struct {
// be returned. This endpoint may not be valid. If StrictMatching is // be returned. This endpoint may not be valid. If StrictMatching is
// enabled only services that are known to support dualstack will return // enabled only services that are known to support dualstack will return
// dualstack endpoints. // dualstack endpoints.
//
// Deprecated: This option will continue to function for S3 and S3 Control for backwards compatibility.
// UseDualStackEndpoint should be used to enable usage of a service's dual-stack endpoint for all service clients
// moving forward. For S3 and S3 Control, when UseDualStackEndpoint is set to a non-zero value it takes higher
// precedence then this option.
UseDualStack bool UseDualStack bool
// Sets the resolver to resolve a dual-stack endpoint for the service.
UseDualStackEndpoint DualStackEndpointState
// UseFIPSEndpoint specifies the resolver must resolve a FIPS endpoint.
UseFIPSEndpoint FIPSEndpointState
// Enables strict matching of services and regions resolved endpoints. // Enables strict matching of services and regions resolved endpoints.
// If the partition doesn't enumerate the exact service and region an // If the partition doesn't enumerate the exact service and region an
// error will be returned. This option will prevent returning endpoints // error will be returned. This option will prevent returning endpoints
@ -48,11 +94,65 @@ type Options struct {
// This option is ignored if StrictMatching is enabled. // This option is ignored if StrictMatching is enabled.
ResolveUnknownService bool ResolveUnknownService bool
// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
EC2MetadataEndpointMode EC2IMDSEndpointModeState
// STS Regional Endpoint flag helps with resolving the STS endpoint // STS Regional Endpoint flag helps with resolving the STS endpoint
STSRegionalEndpoint STSRegionalEndpoint STSRegionalEndpoint STSRegionalEndpoint
// S3 Regional Endpoint flag helps with resolving the S3 endpoint // S3 Regional Endpoint flag helps with resolving the S3 endpoint
S3UsEast1RegionalEndpoint S3UsEast1RegionalEndpoint S3UsEast1RegionalEndpoint S3UsEast1RegionalEndpoint
// ResolvedRegion is the resolved region string. If provided (non-zero length) it takes priority
// over the region name passed to the ResolveEndpoint call.
ResolvedRegion string
// Logger is the logger that will be used to log messages.
Logger Logger
// Determines whether logging of deprecated endpoints usage is enabled.
LogDeprecated bool
}
func (o Options) getEndpointVariant(service string) (v endpointVariant) {
const s3 = "s3"
const s3Control = "s3-control"
if (o.UseDualStackEndpoint == DualStackEndpointStateEnabled) ||
((service == s3 || service == s3Control) && (o.UseDualStackEndpoint == DualStackEndpointStateUnset && o.UseDualStack)) {
v |= dualStackVariant
}
if o.UseFIPSEndpoint == FIPSEndpointStateEnabled {
v |= fipsVariant
}
return v
}
// EC2IMDSEndpointModeState is an enum configuration variable describing the client endpoint mode.
type EC2IMDSEndpointModeState uint
// Enumeration values for EC2IMDSEndpointModeState
const (
EC2IMDSEndpointModeStateUnset EC2IMDSEndpointModeState = iota
EC2IMDSEndpointModeStateIPv4
EC2IMDSEndpointModeStateIPv6
)
// SetFromString sets the EC2IMDSEndpointModeState based on the provided string value. Unknown values will default to EC2IMDSEndpointModeStateUnset
func (e *EC2IMDSEndpointModeState) SetFromString(v string) error {
v = strings.TrimSpace(v)
switch {
case len(v) == 0:
*e = EC2IMDSEndpointModeStateUnset
case strings.EqualFold(v, "IPv6"):
*e = EC2IMDSEndpointModeStateIPv6
case strings.EqualFold(v, "IPv4"):
*e = EC2IMDSEndpointModeStateIPv4
default:
return fmt.Errorf("unknown EC2 IMDS endpoint mode, must be either IPv6 or IPv4")
}
return nil
} }
// STSRegionalEndpoint is an enum for the states of the STS Regional Endpoint // STSRegionalEndpoint is an enum for the states of the STS Regional Endpoint
@ -166,10 +266,25 @@ func DisableSSLOption(o *Options) {
// UseDualStackOption sets the UseDualStack option. Can be used as a functional // UseDualStackOption sets the UseDualStack option. Can be used as a functional
// option when resolving endpoints. // option when resolving endpoints.
//
// Deprecated: UseDualStackEndpointOption should be used to enable usage of a service's dual-stack endpoint.
// When DualStackEndpointState is set to a non-zero value it takes higher precedence then this option.
func UseDualStackOption(o *Options) { func UseDualStackOption(o *Options) {
o.UseDualStack = true o.UseDualStack = true
} }
// UseDualStackEndpointOption sets the UseDualStackEndpoint option to enabled. Can be used as a functional
// option when resolving endpoints.
func UseDualStackEndpointOption(o *Options) {
o.UseDualStackEndpoint = DualStackEndpointStateEnabled
}
// UseFIPSEndpointOption sets the UseFIPSEndpoint option to enabled. Can be used as a functional
// option when resolving endpoints.
func UseFIPSEndpointOption(o *Options) {
o.UseFIPSEndpoint = FIPSEndpointStateEnabled
}
// StrictMatchingOption sets the StrictMatching option. Can be used as a functional // StrictMatchingOption sets the StrictMatching option. Can be used as a functional
// option when resolving endpoints. // option when resolving endpoints.
func StrictMatchingOption(o *Options) { func StrictMatchingOption(o *Options) {
@ -238,16 +353,18 @@ type EnumPartitions interface {
// as the second parameter. // as the second parameter.
// //
// This example shows how to get the regions for DynamoDB in the AWS partition. // This example shows how to get the regions for DynamoDB in the AWS partition.
// rs, exists := endpoints.RegionsForService(endpoints.DefaultPartitions(), endpoints.AwsPartitionID, endpoints.DynamodbServiceID) //
// rs, exists := endpoints.RegionsForService(endpoints.DefaultPartitions(), endpoints.AwsPartitionID, endpoints.DynamodbServiceID)
// //
// This is equivalent to using the partition directly. // This is equivalent to using the partition directly.
// rs := endpoints.AwsPartition().Services()[endpoints.DynamodbServiceID].Regions() //
// rs := endpoints.AwsPartition().Services()[endpoints.DynamodbServiceID].Regions()
func RegionsForService(ps []Partition, partitionID, serviceID string) (map[string]Region, bool) { func RegionsForService(ps []Partition, partitionID, serviceID string) (map[string]Region, bool) {
for _, p := range ps { for _, p := range ps {
if p.ID() != partitionID { if p.ID() != partitionID {
continue continue
} }
if _, ok := p.p.Services[serviceID]; !ok { if _, ok := p.p.Services[serviceID]; !(ok || serviceID == Ec2metadataServiceID) {
break break
} }
@ -308,8 +425,8 @@ func (p Partition) ID() string { return p.id }
// of new regions and services expansions. // of new regions and services expansions.
// //
// Errors that can be returned. // Errors that can be returned.
// * UnknownServiceError // - UnknownServiceError
// * UnknownEndpointError // - UnknownEndpointError
func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) { func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return p.p.EndpointFor(service, region, opts...) return p.p.EndpointFor(service, region, opts...)
} }
@ -333,6 +450,7 @@ func (p Partition) Regions() map[string]Region {
// enumerating over the services in a partition. // enumerating over the services in a partition.
func (p Partition) Services() map[string]Service { func (p Partition) Services() map[string]Service {
ss := make(map[string]Service, len(p.p.Services)) ss := make(map[string]Service, len(p.p.Services))
for id := range p.p.Services { for id := range p.p.Services {
ss[id] = Service{ ss[id] = Service{
id: id, id: id,
@ -340,6 +458,15 @@ func (p Partition) Services() map[string]Service {
} }
} }
// Since we have removed the customization that injected this into the model
// we still need to pretend that this is a modeled service.
if _, ok := ss[Ec2metadataServiceID]; !ok {
ss[Ec2metadataServiceID] = Service{
id: Ec2metadataServiceID,
p: p.p,
}
}
return ss return ss
} }
@ -367,7 +494,7 @@ func (r Region) ResolveEndpoint(service string, opts ...func(*Options)) (Resolve
func (r Region) Services() map[string]Service { func (r Region) Services() map[string]Service {
ss := map[string]Service{} ss := map[string]Service{}
for id, s := range r.p.Services { for id, s := range r.p.Services {
if _, ok := s.Endpoints[r.id]; ok { if _, ok := s.Endpoints[endpointKey{Region: r.id}]; ok {
ss[id] = Service{ ss[id] = Service{
id: id, id: id,
p: r.p, p: r.p,
@ -400,10 +527,24 @@ func (s Service) ResolveEndpoint(region string, opts ...func(*Options)) (Resolve
// an URL that can be resolved to a instance of a service. // an URL that can be resolved to a instance of a service.
func (s Service) Regions() map[string]Region { func (s Service) Regions() map[string]Region {
rs := map[string]Region{} rs := map[string]Region{}
for id := range s.p.Services[s.id].Endpoints {
if r, ok := s.p.Regions[id]; ok { service, ok := s.p.Services[s.id]
rs[id] = Region{
id: id, // Since ec2metadata customization has been removed we need to check
// if it was defined in non-standard endpoints.json file. If it's not
// then we can return the empty map as there is no regional-endpoints for IMDS.
// Otherwise, we iterate need to iterate the non-standard model.
if s.id == Ec2metadataServiceID && !ok {
return rs
}
for id := range service.Endpoints {
if id.Variant != 0 {
continue
}
if r, ok := s.p.Regions[id.Region]; ok {
rs[id.Region] = Region{
id: id.Region,
desc: r.Description, desc: r.Description,
p: s.p, p: s.p,
} }
@ -421,8 +562,11 @@ func (s Service) Regions() map[string]Region {
func (s Service) Endpoints() map[string]Endpoint { func (s Service) Endpoints() map[string]Endpoint {
es := make(map[string]Endpoint, len(s.p.Services[s.id].Endpoints)) es := make(map[string]Endpoint, len(s.p.Services[s.id].Endpoints))
for id := range s.p.Services[s.id].Endpoints { for id := range s.p.Services[s.id].Endpoints {
es[id] = Endpoint{ if id.Variant != 0 {
id: id, continue
}
es[id.Region] = Endpoint{
id: id.Region,
serviceID: s.id, serviceID: s.id,
p: s.p, p: s.p,
} }

View File

@ -1,12 +1,46 @@
package endpoints package endpoints
import ( import (
"encoding/json"
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
) )
const (
ec2MetadataEndpointIPv6 = "http://[fd00:ec2::254]/latest"
ec2MetadataEndpointIPv4 = "http://169.254.169.254/latest"
)
const dnsSuffixTemplateKey = "{dnsSuffix}"
// defaultKey is a compound map key of a variant and other values.
type defaultKey struct {
Variant endpointVariant
ServiceVariant serviceVariant
}
// endpointKey is a compound map key of a region and associated variant value.
type endpointKey struct {
Region string
Variant endpointVariant
}
// endpointVariant is a bit field to describe the endpoints attributes.
type endpointVariant uint64
// serviceVariant is a bit field to describe the service endpoint attributes.
type serviceVariant uint64
const (
// fipsVariant indicates that the endpoint is FIPS capable.
fipsVariant endpointVariant = 1 << (64 - 1 - iota)
// dualStackVariant indicates that the endpoint is DualStack capable.
dualStackVariant
)
var regionValidationRegex = regexp.MustCompile(`^[[:alnum:]]([[:alnum:]\-]*[[:alnum:]])?$`) var regionValidationRegex = regexp.MustCompile(`^[[:alnum:]]([[:alnum:]\-]*[[:alnum:]])?$`)
type partitions []partition type partitions []partition
@ -15,8 +49,12 @@ func (ps partitions) EndpointFor(service, region string, opts ...func(*Options))
var opt Options var opt Options
opt.Set(opts...) opt.Set(opts...)
if len(opt.ResolvedRegion) > 0 {
region = opt.ResolvedRegion
}
for i := 0; i < len(ps); i++ { for i := 0; i < len(ps); i++ {
if !ps[i].canResolveEndpoint(service, region, opt.StrictMatching) { if !ps[i].canResolveEndpoint(service, region, opt) {
continue continue
} }
@ -44,14 +82,76 @@ func (ps partitions) Partitions() []Partition {
return parts return parts
} }
type endpointWithVariants struct {
endpoint
Variants []endpointWithTags `json:"variants"`
}
type endpointWithTags struct {
endpoint
Tags []string `json:"tags"`
}
type endpointDefaults map[defaultKey]endpoint
func (p *endpointDefaults) UnmarshalJSON(data []byte) error {
if *p == nil {
*p = make(endpointDefaults)
}
var e endpointWithVariants
if err := json.Unmarshal(data, &e); err != nil {
return err
}
(*p)[defaultKey{Variant: 0}] = e.endpoint
e.Hostname = ""
e.DNSSuffix = ""
for _, variant := range e.Variants {
endpointVariant, unknown := parseVariantTags(variant.Tags)
if unknown {
continue
}
var ve endpoint
ve.mergeIn(e.endpoint)
ve.mergeIn(variant.endpoint)
(*p)[defaultKey{Variant: endpointVariant}] = ve
}
return nil
}
func parseVariantTags(tags []string) (ev endpointVariant, unknown bool) {
if len(tags) == 0 {
unknown = true
return
}
for _, tag := range tags {
switch {
case strings.EqualFold("fips", tag):
ev |= fipsVariant
case strings.EqualFold("dualstack", tag):
ev |= dualStackVariant
default:
unknown = true
}
}
return ev, unknown
}
type partition struct { type partition struct {
ID string `json:"partition"` ID string `json:"partition"`
Name string `json:"partitionName"` Name string `json:"partitionName"`
DNSSuffix string `json:"dnsSuffix"` DNSSuffix string `json:"dnsSuffix"`
RegionRegex regionRegex `json:"regionRegex"` RegionRegex regionRegex `json:"regionRegex"`
Defaults endpoint `json:"defaults"` Defaults endpointDefaults `json:"defaults"`
Regions regions `json:"regions"` Regions regions `json:"regions"`
Services services `json:"services"` Services services `json:"services"`
} }
func (p partition) Partition() Partition { func (p partition) Partition() Partition {
@ -62,15 +162,18 @@ func (p partition) Partition() Partition {
} }
} }
func (p partition) canResolveEndpoint(service, region string, strictMatch bool) bool { func (p partition) canResolveEndpoint(service, region string, options Options) bool {
s, hasService := p.Services[service] s, hasService := p.Services[service]
_, hasEndpoint := s.Endpoints[region] _, hasEndpoint := s.Endpoints[endpointKey{
Region: region,
Variant: options.getEndpointVariant(service),
}]
if hasEndpoint && hasService { if hasEndpoint && hasService {
return true return true
} }
if strictMatch { if options.StrictMatching {
return false return false
} }
@ -101,7 +204,17 @@ func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (
var opt Options var opt Options
opt.Set(opts...) opt.Set(opts...)
if len(opt.ResolvedRegion) > 0 {
region = opt.ResolvedRegion
}
s, hasService := p.Services[service] s, hasService := p.Services[service]
if service == Ec2metadataServiceID && !hasService {
endpoint := getEC2MetadataEndpoint(p.ID, service, opt.EC2MetadataEndpointMode)
return endpoint, nil
}
if len(service) == 0 || !(hasService || opt.ResolveUnknownService) { if len(service) == 0 || !(hasService || opt.ResolveUnknownService) {
// Only return error if the resolver will not fallback to creating // Only return error if the resolver will not fallback to creating
// endpoint based on service endpoint ID passed in. // endpoint based on service endpoint ID passed in.
@ -112,21 +225,94 @@ func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (
region = s.PartitionEndpoint region = s.PartitionEndpoint
} }
if (service == "sts" && opt.STSRegionalEndpoint != RegionalSTSEndpoint) || if r, ok := isLegacyGlobalRegion(service, region, opt); ok {
(service == "s3" && opt.S3UsEast1RegionalEndpoint != RegionalS3UsEast1Endpoint) { region = r
}
variant := opt.getEndpointVariant(service)
endpoints := s.Endpoints
serviceDefaults, hasServiceDefault := s.Defaults[defaultKey{Variant: variant}]
// If we searched for a variant which may have no explicit service defaults,
// then we need to inherit the standard service defaults except the hostname and dnsSuffix
if variant != 0 && !hasServiceDefault {
serviceDefaults = s.Defaults[defaultKey{}]
serviceDefaults.Hostname = ""
serviceDefaults.DNSSuffix = ""
}
partitionDefaults, hasPartitionDefault := p.Defaults[defaultKey{Variant: variant}]
var dnsSuffix string
if len(serviceDefaults.DNSSuffix) > 0 {
dnsSuffix = serviceDefaults.DNSSuffix
} else if variant == 0 {
// For legacy reasons the partition dnsSuffix is not in the defaults, so if we looked for
// a non-variant endpoint then we need to set the dnsSuffix.
dnsSuffix = p.DNSSuffix
}
noDefaults := !hasServiceDefault && !hasPartitionDefault
e, hasEndpoint := s.endpointForRegion(region, endpoints, variant)
if len(region) == 0 || (!hasEndpoint && (opt.StrictMatching || noDefaults)) {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(endpoints, variant))
}
defs := []endpoint{partitionDefaults, serviceDefaults}
return e.resolve(service, p.ID, region, dnsSuffixTemplateKey, dnsSuffix, defs, opt)
}
func getEC2MetadataEndpoint(partitionID, service string, mode EC2IMDSEndpointModeState) ResolvedEndpoint {
switch mode {
case EC2IMDSEndpointModeStateIPv6:
return ResolvedEndpoint{
URL: ec2MetadataEndpointIPv6,
PartitionID: partitionID,
SigningRegion: "aws-global",
SigningName: service,
SigningNameDerived: true,
SigningMethod: "v4",
}
case EC2IMDSEndpointModeStateIPv4:
fallthrough
default:
return ResolvedEndpoint{
URL: ec2MetadataEndpointIPv4,
PartitionID: partitionID,
SigningRegion: "aws-global",
SigningName: service,
SigningNameDerived: true,
SigningMethod: "v4",
}
}
}
func isLegacyGlobalRegion(service string, region string, opt Options) (string, bool) {
if opt.getEndpointVariant(service) != 0 {
return "", false
}
const (
sts = "sts"
s3 = "s3"
awsGlobal = "aws-global"
)
switch {
case service == sts && opt.STSRegionalEndpoint == RegionalSTSEndpoint:
return region, false
case service == s3 && opt.S3UsEast1RegionalEndpoint == RegionalS3UsEast1Endpoint:
return region, false
default:
if _, ok := legacyGlobalRegions[service][region]; ok { if _, ok := legacyGlobalRegions[service][region]; ok {
region = "aws-global" return awsGlobal, true
} }
} }
e, hasEndpoint := s.endpointForRegion(region) return region, false
if len(region) == 0 || (!hasEndpoint && opt.StrictMatching) {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
}
defs := []endpoint{p.Defaults, s.Defaults}
return e.resolve(service, p.ID, region, p.DNSSuffix, defs, opt)
} }
func serviceList(ss services) []string { func serviceList(ss services) []string {
@ -136,10 +322,13 @@ func serviceList(ss services) []string {
} }
return list return list
} }
func endpointList(es endpoints) []string { func endpointList(es serviceEndpoints, variant endpointVariant) []string {
list := make([]string, 0, len(es)) list := make([]string, 0, len(es))
for k := range es { for k := range es {
list = append(list, k) if k.Variant != variant {
continue
}
list = append(list, k.Region)
} }
return list return list
} }
@ -171,19 +360,19 @@ type region struct {
type services map[string]service type services map[string]service
type service struct { type service struct {
PartitionEndpoint string `json:"partitionEndpoint"` PartitionEndpoint string `json:"partitionEndpoint"`
IsRegionalized boxedBool `json:"isRegionalized,omitempty"` IsRegionalized boxedBool `json:"isRegionalized,omitempty"`
Defaults endpoint `json:"defaults"` Defaults endpointDefaults `json:"defaults"`
Endpoints endpoints `json:"endpoints"` Endpoints serviceEndpoints `json:"endpoints"`
} }
func (s *service) endpointForRegion(region string) (endpoint, bool) { func (s *service) endpointForRegion(region string, endpoints serviceEndpoints, variant endpointVariant) (endpoint, bool) {
if s.IsRegionalized == boxedFalse { if e, ok := endpoints[endpointKey{Region: region, Variant: variant}]; ok {
return s.Endpoints[s.PartitionEndpoint], region == s.PartitionEndpoint return e, true
} }
if e, ok := s.Endpoints[region]; ok { if s.IsRegionalized == boxedFalse {
return e, true return endpoints[endpointKey{Region: s.PartitionEndpoint, Variant: variant}], region == s.PartitionEndpoint
} }
// Unable to find any matching endpoint, return // Unable to find any matching endpoint, return
@ -191,22 +380,73 @@ func (s *service) endpointForRegion(region string) (endpoint, bool) {
return endpoint{}, false return endpoint{}, false
} }
type endpoints map[string]endpoint type serviceEndpoints map[endpointKey]endpoint
func (s *serviceEndpoints) UnmarshalJSON(data []byte) error {
if *s == nil {
*s = make(serviceEndpoints)
}
var regionToEndpoint map[string]endpointWithVariants
if err := json.Unmarshal(data, &regionToEndpoint); err != nil {
return err
}
for region, e := range regionToEndpoint {
(*s)[endpointKey{Region: region}] = e.endpoint
e.Hostname = ""
e.DNSSuffix = ""
for _, variant := range e.Variants {
endpointVariant, unknown := parseVariantTags(variant.Tags)
if unknown {
continue
}
var ve endpoint
ve.mergeIn(e.endpoint)
ve.mergeIn(variant.endpoint)
(*s)[endpointKey{Region: region, Variant: endpointVariant}] = ve
}
}
return nil
}
type endpoint struct { type endpoint struct {
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
Protocols []string `json:"protocols"` Protocols []string `json:"protocols"`
CredentialScope credentialScope `json:"credentialScope"` CredentialScope credentialScope `json:"credentialScope"`
// Custom fields not modeled DNSSuffix string `json:"dnsSuffix"`
HasDualStack boxedBool `json:"-"`
DualStackHostname string `json:"-"`
// Signature Version not used // Signature Version not used
SignatureVersions []string `json:"signatureVersions"` SignatureVersions []string `json:"signatureVersions"`
// SSLCommonName not used. // SSLCommonName not used.
SSLCommonName string `json:"sslCommonName"` SSLCommonName string `json:"sslCommonName"`
Deprecated boxedBool `json:"deprecated"`
}
// isZero returns whether the endpoint structure is an empty (zero) value.
func (e endpoint) isZero() bool {
switch {
case len(e.Hostname) != 0:
return false
case len(e.Protocols) != 0:
return false
case e.CredentialScope != (credentialScope{}):
return false
case len(e.SignatureVersions) != 0:
return false
case len(e.SSLCommonName) != 0:
return false
}
return true
} }
const ( const (
@ -235,7 +475,7 @@ func getByPriority(s []string, p []string, def string) string {
return s[0] return s[0]
} }
func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs []endpoint, opts Options) (ResolvedEndpoint, error) { func (e endpoint) resolve(service, partitionID, region, dnsSuffixTemplateVariable, dnsSuffix string, defs []endpoint, opts Options) (ResolvedEndpoint, error) {
var merged endpoint var merged endpoint
for _, def := range defs { for _, def := range defs {
merged.mergeIn(def) merged.mergeIn(def)
@ -256,23 +496,26 @@ func (e endpoint) resolve(service, partitionID, region, dnsSuffix string, defs [
} }
hostname := e.Hostname hostname := e.Hostname
// Offset the hostname for dualstack if enabled
if opts.UseDualStack && e.HasDualStack == boxedTrue {
hostname = e.DualStackHostname
region = signingRegion
}
if !validateInputRegion(region) { if !validateInputRegion(region) {
return ResolvedEndpoint{}, fmt.Errorf("invalid region identifier format provided") return ResolvedEndpoint{}, fmt.Errorf("invalid region identifier format provided")
} }
if len(merged.DNSSuffix) > 0 {
dnsSuffix = merged.DNSSuffix
}
u := strings.Replace(hostname, "{service}", service, 1) u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1) u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1) u = strings.Replace(u, dnsSuffixTemplateVariable, dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL) scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u) u = fmt.Sprintf("%s://%s", scheme, u)
if e.Deprecated == boxedTrue && opts.LogDeprecated && opts.Logger != nil {
opts.Logger.Log(fmt.Sprintf("endpoint identifier %q, url %q marked as deprecated", region, u))
}
return ResolvedEndpoint{ return ResolvedEndpoint{
URL: u, URL: u,
PartitionID: partitionID, PartitionID: partitionID,
@ -310,11 +553,11 @@ func (e *endpoint) mergeIn(other endpoint) {
if len(other.SSLCommonName) > 0 { if len(other.SSLCommonName) > 0 {
e.SSLCommonName = other.SSLCommonName e.SSLCommonName = other.SSLCommonName
} }
if other.HasDualStack != boxedBoolUnset { if len(other.DNSSuffix) > 0 {
e.HasDualStack = other.HasDualStack e.DNSSuffix = other.DNSSuffix
} }
if len(other.DualStackHostname) > 0 { if other.Deprecated != boxedBoolUnset {
e.DualStackHostname = other.DualStackHostname e.Deprecated = other.Deprecated
} }
} }

View File

@ -1,3 +1,4 @@
//go:build codegen
// +build codegen // +build codegen
package endpoints package endpoints
@ -154,18 +155,71 @@ func serviceSet(ps partitions) map[string]struct{} {
return set return set
} }
func endpointVariantSetter(variant endpointVariant) (string, error) {
if variant == 0 {
return "0", nil
}
if variant > (fipsVariant | dualStackVariant) {
return "", fmt.Errorf("unknown endpoint variant")
}
var symbols []string
if variant&fipsVariant != 0 {
symbols = append(symbols, "fipsVariant")
}
if variant&dualStackVariant != 0 {
symbols = append(symbols, "dualStackVariant")
}
v := strings.Join(symbols, "|")
return v, nil
}
func endpointKeySetter(e endpointKey) (string, error) {
var sb strings.Builder
sb.WriteString("endpointKey{\n")
sb.WriteString(fmt.Sprintf("Region: %q,\n", e.Region))
if e.Variant != 0 {
variantSetter, err := endpointVariantSetter(e.Variant)
if err != nil {
return "", err
}
sb.WriteString(fmt.Sprintf("Variant: %s,\n", variantSetter))
}
sb.WriteString("}")
return sb.String(), nil
}
func defaultKeySetter(e defaultKey) (string, error) {
var sb strings.Builder
sb.WriteString("defaultKey{\n")
if e.Variant != 0 {
variantSetter, err := endpointVariantSetter(e.Variant)
if err != nil {
return "", err
}
sb.WriteString(fmt.Sprintf("Variant: %s,\n", variantSetter))
}
sb.WriteString("}")
return sb.String(), nil
}
var funcMap = template.FuncMap{ var funcMap = template.FuncMap{
"ToSymbol": toSymbol, "ToSymbol": toSymbol,
"QuoteString": quoteString, "QuoteString": quoteString,
"RegionConst": regionConstName, "RegionConst": regionConstName,
"PartitionGetter": partitionGetter, "PartitionGetter": partitionGetter,
"PartitionVarName": partitionVarName, "PartitionVarName": partitionVarName,
"ListPartitionNames": listPartitionNames, "ListPartitionNames": listPartitionNames,
"BoxedBoolIfSet": boxedBoolIfSet, "BoxedBoolIfSet": boxedBoolIfSet,
"StringIfSet": stringIfSet, "StringIfSet": stringIfSet,
"StringSliceIfSet": stringSliceIfSet, "StringSliceIfSet": stringSliceIfSet,
"EndpointIsSet": endpointIsSet, "EndpointIsSet": endpointIsSet,
"ServicesSet": serviceSet, "ServicesSet": serviceSet,
"EndpointVariantSetter": endpointVariantSetter,
"EndpointKeySetter": endpointKeySetter,
"DefaultKeySetter": defaultKeySetter,
} }
const v3Tmpl = ` const v3Tmpl = `
@ -271,9 +325,9 @@ partition{
{{ StringIfSet "Name: %q,\n" .Name -}} {{ StringIfSet "Name: %q,\n" .Name -}}
{{ StringIfSet "DNSSuffix: %q,\n" .DNSSuffix -}} {{ StringIfSet "DNSSuffix: %q,\n" .DNSSuffix -}}
RegionRegex: {{ template "gocode RegionRegex" .RegionRegex }}, RegionRegex: {{ template "gocode RegionRegex" .RegionRegex }},
{{ if EndpointIsSet .Defaults -}} {{ if (gt (len .Defaults) 0) -}}
Defaults: {{ template "gocode Endpoint" .Defaults }}, Defaults: {{ template "gocode Defaults" .Defaults -}},
{{- end }} {{ end -}}
Regions: {{ template "gocode Regions" .Regions }}, Regions: {{ template "gocode Regions" .Regions }},
Services: {{ template "gocode Services" .Services }}, Services: {{ template "gocode Services" .Services }},
} }
@ -314,19 +368,27 @@ services{
service{ service{
{{ StringIfSet "PartitionEndpoint: %q,\n" .PartitionEndpoint -}} {{ StringIfSet "PartitionEndpoint: %q,\n" .PartitionEndpoint -}}
{{ BoxedBoolIfSet "IsRegionalized: %s,\n" .IsRegionalized -}} {{ BoxedBoolIfSet "IsRegionalized: %s,\n" .IsRegionalized -}}
{{ if EndpointIsSet .Defaults -}} {{ if (gt (len .Defaults) 0) -}}
Defaults: {{ template "gocode Endpoint" .Defaults -}}, Defaults: {{ template "gocode Defaults" .Defaults -}},
{{- end }} {{ end -}}
{{ if .Endpoints -}} {{ if .Endpoints -}}
Endpoints: {{ template "gocode Endpoints" .Endpoints }}, Endpoints: {{ template "gocode Endpoints" .Endpoints }},
{{- end }} {{- end }}
} }
{{- end }} {{- end }}
{{ define "gocode Endpoints" -}} {{ define "gocode Defaults" -}}
endpoints{ endpointDefaults{
{{ range $id, $endpoint := . -}} {{ range $id, $endpoint := . -}}
"{{ $id }}": {{ template "gocode Endpoint" $endpoint }}, {{ DefaultKeySetter $id }}: {{ template "gocode Endpoint" $endpoint }},
{{ end }}
}
{{- end }}
{{ define "gocode Endpoints" -}}
serviceEndpoints{
{{ range $id, $endpoint := . -}}
{{ EndpointKeySetter $id }}: {{ template "gocode Endpoint" $endpoint }},
{{ end }} {{ end }}
} }
{{- end }} {{- end }}
@ -334,6 +396,7 @@ endpoints{
{{ define "gocode Endpoint" -}} {{ define "gocode Endpoint" -}}
endpoint{ endpoint{
{{ StringIfSet "Hostname: %q,\n" .Hostname -}} {{ StringIfSet "Hostname: %q,\n" .Hostname -}}
{{ StringIfSet "DNSSuffix: %q,\n" .DNSSuffix -}}
{{ StringIfSet "SSLCommonName: %q,\n" .SSLCommonName -}} {{ StringIfSet "SSLCommonName: %q,\n" .SSLCommonName -}}
{{ StringSliceIfSet "Protocols: []string{%s},\n" .Protocols -}} {{ StringSliceIfSet "Protocols: []string{%s},\n" .Protocols -}}
{{ StringSliceIfSet "SignatureVersions: []string{%s},\n" .SignatureVersions -}} {{ StringSliceIfSet "SignatureVersions: []string{%s},\n" .SignatureVersions -}}
@ -343,9 +406,7 @@ endpoint{
{{ StringIfSet "Service: %q,\n" .CredentialScope.Service -}} {{ StringIfSet "Service: %q,\n" .CredentialScope.Service -}}
}, },
{{- end }} {{- end }}
{{ BoxedBoolIfSet "HasDualStack: %s,\n" .HasDualStack -}} {{ BoxedBoolIfSet "Deprecated: %s,\n" .Deprecated -}}
{{ StringIfSet "DualStackHostname: %q,\n" .DualStackHostname -}}
} }
{{- end }} {{- end }}
` `

View File

@ -77,6 +77,9 @@ const (
// wire unmarshaled message content of requests and responses made while // wire unmarshaled message content of requests and responses made while
// using the SDK Will also enable LogDebug. // using the SDK Will also enable LogDebug.
LogDebugWithEventStreamBody LogDebugWithEventStreamBody
// LogDebugWithDeprecated states the SDK should log details about deprecated functionality.
LogDebugWithDeprecated
) )
// A Logger is a minimalistic interface for the SDK to log messages to. Should // A Logger is a minimalistic interface for the SDK to log messages to. Should

View File

@ -330,6 +330,9 @@ func MakeAddToUserAgentFreeFormHandler(s string) func(*Request) {
// WithSetRequestHeaders updates the operation request's HTTP header to contain // WithSetRequestHeaders updates the operation request's HTTP header to contain
// the header key value pairs provided. If the header key already exists in the // the header key value pairs provided. If the header key already exists in the
// request's HTTP header set, the existing value(s) will be replaced. // request's HTTP header set, the existing value(s) will be replaced.
//
// Header keys added will be added as canonical format with title casing
// applied via http.Header.Set method.
func WithSetRequestHeaders(h map[string]string) Option { func WithSetRequestHeaders(h map[string]string) Option {
return withRequestHeader(h).SetRequestHeaders return withRequestHeader(h).SetRequestHeaders
} }
@ -338,6 +341,6 @@ type withRequestHeader map[string]string
func (h withRequestHeader) SetRequestHeaders(r *Request) { func (h withRequestHeader) SetRequestHeaders(r *Request) {
for k, v := range h { for k, v := range h {
r.HTTPRequest.Header[k] = []string{v} r.HTTPRequest.Header.Set(k, v)
} }
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
@ -129,12 +130,27 @@ func New(cfg aws.Config, clientInfo metadata.ClientInfo, handlers Handlers,
httpReq, _ := http.NewRequest(method, "", nil) httpReq, _ := http.NewRequest(method, "", nil)
var err error var err error
httpReq.URL, err = url.Parse(clientInfo.Endpoint + operation.HTTPPath) httpReq.URL, err = url.Parse(clientInfo.Endpoint)
if err != nil { if err != nil {
httpReq.URL = &url.URL{} httpReq.URL = &url.URL{}
err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err) err = awserr.New("InvalidEndpointURL", "invalid endpoint uri", err)
} }
if len(operation.HTTPPath) != 0 {
opHTTPPath := operation.HTTPPath
var opQueryString string
if idx := strings.Index(opHTTPPath, "?"); idx >= 0 {
opQueryString = opHTTPPath[idx+1:]
opHTTPPath = opHTTPPath[:idx]
}
if strings.HasSuffix(httpReq.URL.Path, "/") && strings.HasPrefix(opHTTPPath, "/") {
opHTTPPath = opHTTPPath[1:]
}
httpReq.URL.Path += opHTTPPath
httpReq.URL.RawQuery = opQueryString
}
r := &Request{ r := &Request{
Config: cfg, Config: cfg,
ClientInfo: clientInfo, ClientInfo: clientInfo,
@ -510,6 +526,14 @@ func (r *Request) GetBody() io.ReadSeeker {
// Send will not close the request.Request's body. // Send will not close the request.Request's body.
func (r *Request) Send() error { func (r *Request) Send() error {
defer func() { defer func() {
// Ensure a non-nil HTTPResponse parameter is set to ensure handlers
// checking for HTTPResponse values, don't fail.
if r.HTTPResponse == nil {
r.HTTPResponse = &http.Response{
Header: http.Header{},
Body: ioutil.NopCloser(&bytes.Buffer{}),
}
}
// Regardless of success or failure of the request trigger the Complete // Regardless of success or failure of the request trigger the Complete
// request handlers. // request handlers.
r.Handlers.Complete.Run(r) r.Handlers.Complete.Run(r)

View File

@ -1,3 +1,4 @@
//go:build !go1.8
// +build !go1.8 // +build !go1.8
package request package request

View File

@ -1,3 +1,4 @@
//go:build go1.8
// +build go1.8 // +build go1.8
package request package request

View File

@ -1,3 +1,4 @@
//go:build go1.7
// +build go1.7 // +build go1.7
package request package request

View File

@ -1,3 +1,4 @@
//go:build !go1.7
// +build !go1.7 // +build !go1.7
package request package request

View File

@ -15,8 +15,8 @@ import (
// and determine if a request API error should be retried. // and determine if a request API error should be retried.
// //
// client.DefaultRetryer is the SDK's default implementation of the Retryer. It // client.DefaultRetryer is the SDK's default implementation of the Retryer. It
// uses the which uses the Request.IsErrorRetryable and Request.IsErrorThrottle // uses the Request.IsErrorRetryable and Request.IsErrorThrottle methods to
// methods to determine if the request is retried. // determine if the request is retried.
type Retryer interface { type Retryer interface {
// RetryRules return the retry delay that should be used by the SDK before // RetryRules return the retry delay that should be used by the SDK before
// making another request attempt for the failed request. // making another request attempt for the failed request.

View File

@ -23,6 +23,7 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws/corehandlers", "//vendor/github.com/aws/aws-sdk-go/aws/corehandlers",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials", "//vendor/github.com/aws/aws-sdk-go/aws/credentials",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials/processcreds", "//vendor/github.com/aws/aws-sdk-go/aws/credentials/processcreds",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials/ssocreds",
"//vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds", "//vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds",
"//vendor/github.com/aws/aws-sdk-go/aws/csm", "//vendor/github.com/aws/aws-sdk-go/aws/csm",
"//vendor/github.com/aws/aws-sdk-go/aws/defaults", "//vendor/github.com/aws/aws-sdk-go/aws/defaults",
@ -30,5 +31,6 @@ go_library(
"//vendor/github.com/aws/aws-sdk-go/aws/request", "//vendor/github.com/aws/aws-sdk-go/aws/request",
"//vendor/github.com/aws/aws-sdk-go/internal/ini", "//vendor/github.com/aws/aws-sdk-go/internal/ini",
"//vendor/github.com/aws/aws-sdk-go/internal/shareddefaults", "//vendor/github.com/aws/aws-sdk-go/internal/shareddefaults",
"//vendor/github.com/aws/aws-sdk-go/service/sts",
], ],
) )

View File

@ -9,12 +9,22 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds" "github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/credentials/ssocreds"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults" "github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/aws/aws-sdk-go/service/sts"
) )
// CredentialsProviderOptions specifies additional options for configuring
// credentials providers.
type CredentialsProviderOptions struct {
// WebIdentityRoleProviderOptions configures a WebIdentityRoleProvider,
// such as setting its ExpiryWindow.
WebIdentityRoleProviderOptions func(*stscreds.WebIdentityRoleProvider)
}
func resolveCredentials(cfg *aws.Config, func resolveCredentials(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig, envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers, handlers request.Handlers,
@ -39,6 +49,7 @@ func resolveCredentials(cfg *aws.Config,
envCfg.WebIdentityTokenFilePath, envCfg.WebIdentityTokenFilePath,
envCfg.RoleARN, envCfg.RoleARN,
envCfg.RoleSessionName, envCfg.RoleSessionName,
sessOpts.CredentialsProviderOptions,
) )
default: default:
@ -58,6 +69,7 @@ var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers, func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
filepath string, filepath string,
roleARN, sessionName string, roleARN, sessionName string,
credOptions *CredentialsProviderOptions,
) (*credentials.Credentials, error) { ) (*credentials.Credentials, error) {
if len(filepath) == 0 { if len(filepath) == 0 {
@ -68,17 +80,18 @@ func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
return nil, WebIdentityEmptyRoleARNErr return nil, WebIdentityEmptyRoleARNErr
} }
creds := stscreds.NewWebIdentityCredentials( svc := sts.New(&Session{
&Session{ Config: cfg,
Config: cfg, Handlers: handlers.Copy(),
Handlers: handlers.Copy(), })
},
roleARN,
sessionName,
filepath,
)
return creds, nil var optFns []func(*stscreds.WebIdentityRoleProvider)
if credOptions != nil && credOptions.WebIdentityRoleProviderOptions != nil {
optFns = append(optFns, credOptions.WebIdentityRoleProviderOptions)
}
p := stscreds.NewWebIdentityRoleProviderWithOptions(svc, roleARN, sessionName, stscreds.FetchTokenPath(filepath), optFns...)
return credentials.NewCredentials(p), nil
} }
func resolveCredsFromProfile(cfg *aws.Config, func resolveCredsFromProfile(cfg *aws.Config,
@ -100,10 +113,6 @@ func resolveCredsFromProfile(cfg *aws.Config,
sharedCfg.Creds, sharedCfg.Creds,
) )
case len(sharedCfg.CredentialProcess) != 0:
// Get credentials from CredentialProcess
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
case len(sharedCfg.CredentialSource) != 0: case len(sharedCfg.CredentialSource) != 0:
creds, err = resolveCredsFromSource(cfg, envCfg, creds, err = resolveCredsFromSource(cfg, envCfg,
sharedCfg, handlers, sessOpts, sharedCfg, handlers, sessOpts,
@ -117,8 +126,16 @@ func resolveCredsFromProfile(cfg *aws.Config,
sharedCfg.WebIdentityTokenFile, sharedCfg.WebIdentityTokenFile,
sharedCfg.RoleARN, sharedCfg.RoleARN,
sharedCfg.RoleSessionName, sharedCfg.RoleSessionName,
sessOpts.CredentialsProviderOptions,
) )
case sharedCfg.hasSSOConfiguration():
creds, err = resolveSSOCredentials(cfg, sharedCfg, handlers)
case len(sharedCfg.CredentialProcess) != 0:
// Get credentials from CredentialProcess
creds = processcreds.NewCredentials(sharedCfg.CredentialProcess)
default: default:
// Fallback to default credentials provider, include mock errors for // Fallback to default credentials provider, include mock errors for
// the credential chain so user can identify why credentials failed to // the credential chain so user can identify why credentials failed to
@ -151,6 +168,25 @@ func resolveCredsFromProfile(cfg *aws.Config,
return creds, nil return creds, nil
} }
func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers request.Handlers) (*credentials.Credentials, error) {
if err := sharedCfg.validateSSOConfiguration(); err != nil {
return nil, err
}
cfgCopy := cfg.Copy()
cfgCopy.Region = &sharedCfg.SSORegion
return ssocreds.NewCredentials(
&Session{
Config: cfgCopy,
Handlers: handlers.Copy(),
},
sharedCfg.SSOAccountID,
sharedCfg.SSORoleName,
sharedCfg.SSOStartURL,
), nil
}
// valid credential source values // valid credential source values
const ( const (
credSourceEc2Metadata = "Ec2InstanceMetadata" credSourceEc2Metadata = "Ec2InstanceMetadata"

View File

@ -1,3 +1,4 @@
//go:build go1.13
// +build go1.13 // +build go1.13
package session package session

View File

@ -1,3 +1,4 @@
//go:build !go1.13 && go1.7
// +build !go1.13,go1.7 // +build !go1.13,go1.7
package session package session

View File

@ -1,3 +1,4 @@
//go:build !go1.6 && go1.5
// +build !go1.6,go1.5 // +build !go1.6,go1.5
package session package session

View File

@ -1,3 +1,4 @@
//go:build !go1.7 && go1.6
// +build !go1.7,go1.6 // +build !go1.7,go1.6
package session package session

View File

@ -283,7 +283,85 @@ component must be enclosed in square brackets.
The custom EC2 IMDS endpoint can also be specified via the Session options. The custom EC2 IMDS endpoint can also be specified via the Session options.
sess, err := session.NewSessionWithOptions(session.Options{ sess, err := session.NewSessionWithOptions(session.Options{
EC2IMDSEndpoint: "http://[::1]", EC2MetadataEndpoint: "http://[::1]",
})
FIPS and DualStack Endpoints
The SDK can be configured to resolve an endpoint with certain capabilities such as FIPS and DualStack.
You can configure a FIPS endpoint using an environment variable, shared config ($HOME/.aws/config),
or programmatically.
To configure a FIPS endpoint set the environment variable set the AWS_USE_FIPS_ENDPOINT to true or false to enable
or disable FIPS endpoint resolution.
AWS_USE_FIPS_ENDPOINT=true
To configure a FIPS endpoint using shared config, set use_fips_endpoint to true or false to enable
or disable FIPS endpoint resolution.
[profile myprofile]
region=us-west-2
use_fips_endpoint=true
To configure a FIPS endpoint programmatically
// Option 1: Configure it on a session for all clients
sess, err := session.NewSessionWithOptions(session.Options{
UseFIPSEndpoint: endpoints.FIPSEndpointStateEnabled,
})
if err != nil {
// handle error
}
client := s3.New(sess)
// Option 2: Configure it per client
sess, err := session.NewSession()
if err != nil {
// handle error
}
client := s3.New(sess, &aws.Config{
UseFIPSEndpoint: endpoints.FIPSEndpointStateEnabled,
})
You can configure a DualStack endpoint using an environment variable, shared config ($HOME/.aws/config),
or programmatically.
To configure a DualStack endpoint set the environment variable set the AWS_USE_DUALSTACK_ENDPOINT to true or false to
enable or disable DualStack endpoint resolution.
AWS_USE_DUALSTACK_ENDPOINT=true
To configure a DualStack endpoint using shared config, set use_dualstack_endpoint to true or false to enable
or disable DualStack endpoint resolution.
[profile myprofile]
region=us-west-2
use_dualstack_endpoint=true
To configure a DualStack endpoint programmatically
// Option 1: Configure it on a session for all clients
sess, err := session.NewSessionWithOptions(session.Options{
UseDualStackEndpoint: endpoints.DualStackEndpointStateEnabled,
})
if err != nil {
// handle error
}
client := s3.New(sess)
// Option 2: Configure it per client
sess, err := session.NewSession()
if err != nil {
// handle error
}
client := s3.New(sess, &aws.Config{
UseDualStackEndpoint: endpoints.DualStackEndpointStateEnabled,
}) })
*/ */
package session package session

View File

@ -161,10 +161,27 @@ type envConfig struct {
// AWS_S3_USE_ARN_REGION=true // AWS_S3_USE_ARN_REGION=true
S3UseARNRegion bool S3UseARNRegion bool
// Specifies the alternative endpoint to use for EC2 IMDS. // Specifies the EC2 Instance Metadata Service endpoint to use. If specified it overrides EC2IMDSEndpointMode.
// //
// AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1] // AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1]
EC2IMDSEndpoint string EC2IMDSEndpoint string
// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6
EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState
// Specifies that SDK clients must resolve a dual-stack endpoint for
// services.
//
// AWS_USE_DUALSTACK_ENDPOINT=true
UseDualStackEndpoint endpoints.DualStackEndpointState
// Specifies that SDK clients must resolve a FIPS endpoint for
// services.
//
// AWS_USE_FIPS_ENDPOINT=true
UseFIPSEndpoint endpoints.FIPSEndpointState
} }
var ( var (
@ -231,6 +248,9 @@ var (
ec2IMDSEndpointEnvKey = []string{ ec2IMDSEndpointEnvKey = []string{
"AWS_EC2_METADATA_SERVICE_ENDPOINT", "AWS_EC2_METADATA_SERVICE_ENDPOINT",
} }
ec2IMDSEndpointModeEnvKey = []string{
"AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE",
}
useCABundleKey = []string{ useCABundleKey = []string{
"AWS_CA_BUNDLE", "AWS_CA_BUNDLE",
} }
@ -240,6 +260,12 @@ var (
useClientTLSKey = []string{ useClientTLSKey = []string{
"AWS_SDK_GO_CLIENT_TLS_KEY", "AWS_SDK_GO_CLIENT_TLS_KEY",
} }
awsUseDualStackEndpoint = []string{
"AWS_USE_DUALSTACK_ENDPOINT",
}
awsUseFIPSEndpoint = []string{
"AWS_USE_FIPS_ENDPOINT",
}
) )
// loadEnvConfig retrieves the SDK's environment configuration. // loadEnvConfig retrieves the SDK's environment configuration.
@ -364,6 +390,17 @@ func envConfigLoad(enableSharedConfig bool) (envConfig, error) {
} }
setFromEnvVal(&cfg.EC2IMDSEndpoint, ec2IMDSEndpointEnvKey) setFromEnvVal(&cfg.EC2IMDSEndpoint, ec2IMDSEndpointEnvKey)
if err := setEC2IMDSEndpointMode(&cfg.EC2IMDSEndpointMode, ec2IMDSEndpointModeEnvKey); err != nil {
return envConfig{}, err
}
if err := setUseDualStackEndpointFromEnvVal(&cfg.UseDualStackEndpoint, awsUseDualStackEndpoint); err != nil {
return cfg, err
}
if err := setUseFIPSEndpointFromEnvVal(&cfg.UseFIPSEndpoint, awsUseFIPSEndpoint); err != nil {
return cfg, err
}
return cfg, nil return cfg, nil
} }
@ -376,3 +413,59 @@ func setFromEnvVal(dst *string, keys []string) {
} }
} }
} }
func setEC2IMDSEndpointMode(mode *endpoints.EC2IMDSEndpointModeState, keys []string) error {
for _, k := range keys {
value := os.Getenv(k)
if len(value) == 0 {
continue
}
if err := mode.SetFromString(value); err != nil {
return fmt.Errorf("invalid value for environment variable, %s=%s, %v", k, value, err)
}
return nil
}
return nil
}
func setUseDualStackEndpointFromEnvVal(dst *endpoints.DualStackEndpointState, keys []string) error {
for _, k := range keys {
value := os.Getenv(k)
if len(value) == 0 {
continue // skip if empty
}
switch {
case strings.EqualFold(value, "true"):
*dst = endpoints.DualStackEndpointStateEnabled
case strings.EqualFold(value, "false"):
*dst = endpoints.DualStackEndpointStateDisabled
default:
return fmt.Errorf(
"invalid value for environment variable, %s=%s, need true, false",
k, value)
}
}
return nil
}
func setUseFIPSEndpointFromEnvVal(dst *endpoints.FIPSEndpointState, keys []string) error {
for _, k := range keys {
value := os.Getenv(k)
if len(value) == 0 {
continue // skip if empty
}
switch {
case strings.EqualFold(value, "true"):
*dst = endpoints.FIPSEndpointStateEnabled
case strings.EqualFold(value, "false"):
*dst = endpoints.FIPSEndpointStateDisabled
default:
return fmt.Errorf(
"invalid value for environment variable, %s=%s, need true, false",
k, value)
}
}
return nil
}

View File

@ -8,6 +8,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
"strings"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -36,7 +37,7 @@ const (
// ErrSharedConfigSourceCollision will be returned if a section contains both // ErrSharedConfigSourceCollision will be returned if a section contains both
// source_profile and credential_source // source_profile and credential_source
var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only source profile or credential source can be specified, not both", nil) var ErrSharedConfigSourceCollision = awserr.New(ErrCodeSharedConfig, "only one credential type may be specified per profile: source profile, credential source, credential process, web identity token, or sso", nil)
// ErrSharedConfigECSContainerEnvVarEmpty will be returned if the environment // ErrSharedConfigECSContainerEnvVarEmpty will be returned if the environment
// variables are empty and Environment was set as the credential source // variables are empty and Environment was set as the credential source
@ -283,8 +284,8 @@ type Options struct {
Handlers request.Handlers Handlers request.Handlers
// Allows specifying a custom endpoint to be used by the EC2 IMDS client // Allows specifying a custom endpoint to be used by the EC2 IMDS client
// when making requests to the EC2 IMDS API. The must endpoint value must // when making requests to the EC2 IMDS API. The endpoint value should
// include protocol prefix. // include the URI scheme. If the scheme is not present it will be defaulted to http.
// //
// If unset, will the EC2 IMDS client will use its default endpoint. // If unset, will the EC2 IMDS client will use its default endpoint.
// //
@ -298,6 +299,16 @@ type Options struct {
// //
// AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1] // AWS_EC2_METADATA_SERVICE_ENDPOINT=http://[::1]
EC2IMDSEndpoint string EC2IMDSEndpoint string
// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6
EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState
// Specifies options for creating credential providers.
// These are only used if the aws.Config does not already
// include credentials.
CredentialsProviderOptions *CredentialsProviderOptions
} }
// NewSessionWithOptions returns a new Session created from SDK defaults, config files, // NewSessionWithOptions returns a new Session created from SDK defaults, config files,
@ -375,19 +386,23 @@ func Must(sess *Session, err error) *Session {
// Wraps the endpoint resolver with a resolver that will return a custom // Wraps the endpoint resolver with a resolver that will return a custom
// endpoint for EC2 IMDS. // endpoint for EC2 IMDS.
func wrapEC2IMDSEndpoint(resolver endpoints.Resolver, endpoint string) endpoints.Resolver { func wrapEC2IMDSEndpoint(resolver endpoints.Resolver, endpoint string, mode endpoints.EC2IMDSEndpointModeState) endpoints.Resolver {
return endpoints.ResolverFunc( return endpoints.ResolverFunc(
func(service, region string, opts ...func(*endpoints.Options)) ( func(service, region string, opts ...func(*endpoints.Options)) (
endpoints.ResolvedEndpoint, error, endpoints.ResolvedEndpoint, error,
) { ) {
if service == ec2MetadataServiceID { if service == ec2MetadataServiceID && len(endpoint) > 0 {
return endpoints.ResolvedEndpoint{ return endpoints.ResolvedEndpoint{
URL: endpoint, URL: endpoint,
SigningName: ec2MetadataServiceID, SigningName: ec2MetadataServiceID,
SigningRegion: region, SigningRegion: region,
}, nil }, nil
} else if service == ec2MetadataServiceID {
opts = append(opts, func(o *endpoints.Options) {
o.EC2MetadataEndpointMode = mode
})
} }
return resolver.EndpointFor(service, region) return resolver.EndpointFor(service, region, opts...)
}) })
} }
@ -404,8 +419,8 @@ func deprecatedNewSession(envCfg envConfig, cfgs ...*aws.Config) *Session {
cfg.EndpointResolver = endpoints.DefaultResolver() cfg.EndpointResolver = endpoints.DefaultResolver()
} }
if len(envCfg.EC2IMDSEndpoint) != 0 { if !(len(envCfg.EC2IMDSEndpoint) == 0 && envCfg.EC2IMDSEndpointMode == endpoints.EC2IMDSEndpointModeStateUnset) {
cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, envCfg.EC2IMDSEndpoint) cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, envCfg.EC2IMDSEndpoint, envCfg.EC2IMDSEndpointMode)
} }
cfg.Credentials = defaults.CredChain(cfg, handlers) cfg.Credentials = defaults.CredChain(cfg, handlers)
@ -737,12 +752,32 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config,
endpoints.LegacyS3UsEast1Endpoint, endpoints.LegacyS3UsEast1Endpoint,
}) })
ec2IMDSEndpoint := sessOpts.EC2IMDSEndpoint var ec2IMDSEndpoint string
if len(ec2IMDSEndpoint) == 0 { for _, v := range []string{
ec2IMDSEndpoint = envCfg.EC2IMDSEndpoint sessOpts.EC2IMDSEndpoint,
envCfg.EC2IMDSEndpoint,
sharedCfg.EC2IMDSEndpoint,
} {
if len(v) != 0 {
ec2IMDSEndpoint = v
break
}
} }
if len(ec2IMDSEndpoint) != 0 {
cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, ec2IMDSEndpoint) var endpointMode endpoints.EC2IMDSEndpointModeState
for _, v := range []endpoints.EC2IMDSEndpointModeState{
sessOpts.EC2IMDSEndpointMode,
envCfg.EC2IMDSEndpointMode,
sharedCfg.EC2IMDSEndpointMode,
} {
if v != endpoints.EC2IMDSEndpointModeStateUnset {
endpointMode = v
break
}
}
if len(ec2IMDSEndpoint) != 0 || endpointMode != endpoints.EC2IMDSEndpointModeStateUnset {
cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, ec2IMDSEndpoint, endpointMode)
} }
// Configure credentials if not already set by the user when creating the // Configure credentials if not already set by the user when creating the
@ -763,6 +798,20 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config,
cfg.S3UseARNRegion = &sharedCfg.S3UseARNRegion cfg.S3UseARNRegion = &sharedCfg.S3UseARNRegion
} }
for _, v := range []endpoints.DualStackEndpointState{userCfg.UseDualStackEndpoint, envCfg.UseDualStackEndpoint, sharedCfg.UseDualStackEndpoint} {
if v != endpoints.DualStackEndpointStateUnset {
cfg.UseDualStackEndpoint = v
break
}
}
for _, v := range []endpoints.FIPSEndpointState{userCfg.UseFIPSEndpoint, envCfg.UseFIPSEndpoint, sharedCfg.UseFIPSEndpoint} {
if v != endpoints.FIPSEndpointStateUnset {
cfg.UseFIPSEndpoint = v
break
}
}
return nil return nil
} }
@ -816,8 +865,10 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session {
func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Config { func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Config {
s = s.Copy(cfgs...) s = s.Copy(cfgs...)
resolvedRegion := normalizeRegion(s.Config)
region := aws.StringValue(s.Config.Region) region := aws.StringValue(s.Config.Region)
resolved, err := s.resolveEndpoint(service, region, s.Config) resolved, err := s.resolveEndpoint(service, region, resolvedRegion, s.Config)
if err != nil { if err != nil {
s.Handlers.Validate.PushBack(func(r *request.Request) { s.Handlers.Validate.PushBack(func(r *request.Request) {
if len(r.ClientInfo.Endpoint) != 0 { if len(r.ClientInfo.Endpoint) != 0 {
@ -838,12 +889,13 @@ func (s *Session) ClientConfig(service string, cfgs ...*aws.Config) client.Confi
SigningRegion: resolved.SigningRegion, SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived, SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName, SigningName: resolved.SigningName,
ResolvedRegion: resolvedRegion,
} }
} }
const ec2MetadataServiceID = "ec2metadata" const ec2MetadataServiceID = "ec2metadata"
func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endpoints.ResolvedEndpoint, error) { func (s *Session) resolveEndpoint(service, region, resolvedRegion string, cfg *aws.Config) (endpoints.ResolvedEndpoint, error) {
if ep := aws.StringValue(cfg.Endpoint); len(ep) != 0 { if ep := aws.StringValue(cfg.Endpoint); len(ep) != 0 {
return endpoints.ResolvedEndpoint{ return endpoints.ResolvedEndpoint{
@ -855,7 +907,12 @@ func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endp
resolved, err := cfg.EndpointResolver.EndpointFor(service, region, resolved, err := cfg.EndpointResolver.EndpointFor(service, region,
func(opt *endpoints.Options) { func(opt *endpoints.Options) {
opt.DisableSSL = aws.BoolValue(cfg.DisableSSL) opt.DisableSSL = aws.BoolValue(cfg.DisableSSL)
opt.UseDualStack = aws.BoolValue(cfg.UseDualStack) opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)
opt.UseDualStackEndpoint = cfg.UseDualStackEndpoint
opt.UseFIPSEndpoint = cfg.UseFIPSEndpoint
// Support for STSRegionalEndpoint where the STSRegionalEndpoint is // Support for STSRegionalEndpoint where the STSRegionalEndpoint is
// provided in envConfig or sharedConfig with envConfig getting // provided in envConfig or sharedConfig with envConfig getting
// precedence. // precedence.
@ -869,6 +926,11 @@ func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endp
// Support the condition where the service is modeled but its // Support the condition where the service is modeled but its
// endpoint metadata is not available. // endpoint metadata is not available.
opt.ResolveUnknownService = true opt.ResolveUnknownService = true
opt.ResolvedRegion = resolvedRegion
opt.Logger = cfg.Logger
opt.LogDeprecated = cfg.LogLevel.Matches(aws.LogDebugWithDeprecated)
}, },
) )
if err != nil { if err != nil {
@ -884,6 +946,8 @@ func (s *Session) resolveEndpoint(service, region string, cfg *aws.Config) (endp
func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Config { func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Config {
s = s.Copy(cfgs...) s = s.Copy(cfgs...)
resolvedRegion := normalizeRegion(s.Config)
var resolved endpoints.ResolvedEndpoint var resolved endpoints.ResolvedEndpoint
if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 { if ep := aws.StringValue(s.Config.Endpoint); len(ep) > 0 {
resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL)) resolved.URL = endpoints.AddScheme(ep, aws.BoolValue(s.Config.DisableSSL))
@ -897,6 +961,7 @@ func (s *Session) ClientConfigNoResolveEndpoint(cfgs ...*aws.Config) client.Conf
SigningRegion: resolved.SigningRegion, SigningRegion: resolved.SigningRegion,
SigningNameDerived: resolved.SigningNameDerived, SigningNameDerived: resolved.SigningNameDerived,
SigningName: resolved.SigningName, SigningName: resolved.SigningName,
ResolvedRegion: resolvedRegion,
} }
} }
@ -910,3 +975,23 @@ func (s *Session) logDeprecatedNewSessionError(msg string, err error, cfgs []*aw
r.Error = err r.Error = err
}) })
} }
// normalizeRegion resolves / normalizes the configured region (converts pseudo fips regions), and modifies the provided
// config to have the equivalent options for resolution and returns the resolved region name.
func normalizeRegion(cfg *aws.Config) (resolved string) {
const fipsInfix = "-fips-"
const fipsPrefix = "-fips"
const fipsSuffix = "fips-"
region := aws.StringValue(cfg.Region)
if strings.Contains(region, fipsInfix) ||
strings.Contains(region, fipsPrefix) ||
strings.Contains(region, fipsSuffix) {
resolved = strings.Replace(strings.Replace(strings.Replace(
region, fipsInfix, "-", -1), fipsPrefix, "", -1), fipsSuffix, "", -1)
cfg.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled
}
return resolved
}

View File

@ -2,6 +2,7 @@ package session
import ( import (
"fmt" "fmt"
"strings"
"time" "time"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
@ -25,6 +26,12 @@ const (
roleSessionNameKey = `role_session_name` // optional roleSessionNameKey = `role_session_name` // optional
roleDurationSecondsKey = "duration_seconds" // optional roleDurationSecondsKey = "duration_seconds" // optional
// AWS Single Sign-On (AWS SSO) group
ssoAccountIDKey = "sso_account_id"
ssoRegionKey = "sso_region"
ssoRoleNameKey = "sso_role_name"
ssoStartURL = "sso_start_url"
// CSM options // CSM options
csmEnabledKey = `csm_enabled` csmEnabledKey = `csm_enabled`
csmHostKey = `csm_host` csmHostKey = `csm_host`
@ -59,10 +66,24 @@ const (
// S3 ARN Region Usage // S3 ARN Region Usage
s3UseARNRegionKey = "s3_use_arn_region" s3UseARNRegionKey = "s3_use_arn_region"
// EC2 IMDS Endpoint Mode
ec2MetadataServiceEndpointModeKey = "ec2_metadata_service_endpoint_mode"
// EC2 IMDS Endpoint
ec2MetadataServiceEndpointKey = "ec2_metadata_service_endpoint"
// Use DualStack Endpoint Resolution
useDualStackEndpoint = "use_dualstack_endpoint"
// Use FIPS Endpoint Resolution
useFIPSEndpointKey = "use_fips_endpoint"
) )
// sharedConfig represents the configuration fields of the SDK config files. // sharedConfig represents the configuration fields of the SDK config files.
type sharedConfig struct { type sharedConfig struct {
Profile string
// Credentials values from the config file. Both aws_access_key_id and // Credentials values from the config file. Both aws_access_key_id and
// aws_secret_access_key must be provided together in the same file to be // aws_secret_access_key must be provided together in the same file to be
// considered valid. The values will be ignored if not a complete group. // considered valid. The values will be ignored if not a complete group.
@ -78,6 +99,11 @@ type sharedConfig struct {
CredentialProcess string CredentialProcess string
WebIdentityTokenFile string WebIdentityTokenFile string
SSOAccountID string
SSORegion string
SSORoleName string
SSOStartURL string
RoleARN string RoleARN string
RoleSessionName string RoleSessionName string
ExternalID string ExternalID string
@ -131,6 +157,28 @@ type sharedConfig struct {
// //
// s3_use_arn_region=true // s3_use_arn_region=true
S3UseARNRegion bool S3UseARNRegion bool
// Specifies the EC2 Instance Metadata Service default endpoint selection mode (IPv4 or IPv6)
//
// ec2_metadata_service_endpoint_mode=IPv6
EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState
// Specifies the EC2 Instance Metadata Service endpoint to use. If specified it overrides EC2IMDSEndpointMode.
//
// ec2_metadata_service_endpoint=http://fd00:ec2::254
EC2IMDSEndpoint string
// Specifies that SDK clients must resolve a dual-stack endpoint for
// services.
//
// use_dualstack_endpoint=true
UseDualStackEndpoint endpoints.DualStackEndpointState
// Specifies that SDK clients must resolve a FIPS endpoint for
// services.
//
// use_fips_endpoint=true
UseFIPSEndpoint endpoints.FIPSEndpointState
} }
type sharedConfigFile struct { type sharedConfigFile struct {
@ -189,6 +237,8 @@ func loadSharedConfigIniFiles(filenames []string) ([]sharedConfigFile, error) {
} }
func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error { func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile string, files []sharedConfigFile, exOpts bool) error {
cfg.Profile = profile
// Trim files from the list that don't exist. // Trim files from the list that don't exist.
var skippedFiles int var skippedFiles int
var profileNotFoundErr error var profileNotFoundErr error
@ -217,9 +267,9 @@ func (cfg *sharedConfig) setFromIniFiles(profiles map[string]struct{}, profile s
cfg.clearAssumeRoleOptions() cfg.clearAssumeRoleOptions()
} else { } else {
// First time a profile has been seen, It must either be a assume role // First time a profile has been seen, It must either be a assume role
// or credentials. Assert if the credential type requires a role ARN, // credentials, or SSO. Assert if the credential type requires a role ARN,
// the ARN is also set. // the ARN is also set, or validate that the SSO configuration is complete.
if err := cfg.validateCredentialsRequireARN(profile); err != nil { if err := cfg.validateCredentialsConfig(profile); err != nil {
return err return err
} }
} }
@ -312,6 +362,22 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, e
} }
cfg.S3UsEast1RegionalEndpoint = sre cfg.S3UsEast1RegionalEndpoint = sre
} }
// AWS Single Sign-On (AWS SSO)
updateString(&cfg.SSOAccountID, section, ssoAccountIDKey)
updateString(&cfg.SSORegion, section, ssoRegionKey)
updateString(&cfg.SSORoleName, section, ssoRoleNameKey)
updateString(&cfg.SSOStartURL, section, ssoStartURL)
if err := updateEC2MetadataServiceEndpointMode(&cfg.EC2IMDSEndpointMode, section, ec2MetadataServiceEndpointModeKey); err != nil {
return fmt.Errorf("failed to load %s from shared config, %s, %v",
ec2MetadataServiceEndpointModeKey, file.Filename, err)
}
updateString(&cfg.EC2IMDSEndpoint, section, ec2MetadataServiceEndpointKey)
updateUseDualStackEndpoint(&cfg.UseDualStackEndpoint, section, useDualStackEndpoint)
updateUseFIPSEndpoint(&cfg.UseFIPSEndpoint, section, useFIPSEndpointKey)
} }
updateString(&cfg.CredentialProcess, section, credentialProcessKey) updateString(&cfg.CredentialProcess, section, credentialProcessKey)
@ -342,6 +408,22 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, e
return nil return nil
} }
func updateEC2MetadataServiceEndpointMode(endpointMode *endpoints.EC2IMDSEndpointModeState, section ini.Section, key string) error {
if !section.Has(key) {
return nil
}
value := section.String(key)
return endpointMode.SetFromString(value)
}
func (cfg *sharedConfig) validateCredentialsConfig(profile string) error {
if err := cfg.validateCredentialsRequireARN(profile); err != nil {
return err
}
return nil
}
func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error { func (cfg *sharedConfig) validateCredentialsRequireARN(profile string) error {
var credSource string var credSource string
@ -378,12 +460,43 @@ func (cfg *sharedConfig) validateCredentialType() error {
return nil return nil
} }
func (cfg *sharedConfig) validateSSOConfiguration() error {
if !cfg.hasSSOConfiguration() {
return nil
}
var missing []string
if len(cfg.SSOAccountID) == 0 {
missing = append(missing, ssoAccountIDKey)
}
if len(cfg.SSORegion) == 0 {
missing = append(missing, ssoRegionKey)
}
if len(cfg.SSORoleName) == 0 {
missing = append(missing, ssoRoleNameKey)
}
if len(cfg.SSOStartURL) == 0 {
missing = append(missing, ssoStartURL)
}
if len(missing) > 0 {
return fmt.Errorf("profile %q is configured to use SSO but is missing required configuration: %s",
cfg.Profile, strings.Join(missing, ", "))
}
return nil
}
func (cfg *sharedConfig) hasCredentials() bool { func (cfg *sharedConfig) hasCredentials() bool {
switch { switch {
case len(cfg.SourceProfileName) != 0: case len(cfg.SourceProfileName) != 0:
case len(cfg.CredentialSource) != 0: case len(cfg.CredentialSource) != 0:
case len(cfg.CredentialProcess) != 0: case len(cfg.CredentialProcess) != 0:
case len(cfg.WebIdentityTokenFile) != 0: case len(cfg.WebIdentityTokenFile) != 0:
case cfg.hasSSOConfiguration():
case cfg.Creds.HasKeys(): case cfg.Creds.HasKeys():
default: default:
return false return false
@ -397,6 +510,10 @@ func (cfg *sharedConfig) clearCredentialOptions() {
cfg.CredentialProcess = "" cfg.CredentialProcess = ""
cfg.WebIdentityTokenFile = "" cfg.WebIdentityTokenFile = ""
cfg.Creds = credentials.Value{} cfg.Creds = credentials.Value{}
cfg.SSOAccountID = ""
cfg.SSORegion = ""
cfg.SSORoleName = ""
cfg.SSOStartURL = ""
} }
func (cfg *sharedConfig) clearAssumeRoleOptions() { func (cfg *sharedConfig) clearAssumeRoleOptions() {
@ -407,6 +524,18 @@ func (cfg *sharedConfig) clearAssumeRoleOptions() {
cfg.SourceProfileName = "" cfg.SourceProfileName = ""
} }
func (cfg *sharedConfig) hasSSOConfiguration() bool {
switch {
case len(cfg.SSOAccountID) != 0:
case len(cfg.SSORegion) != 0:
case len(cfg.SSORoleName) != 0:
case len(cfg.SSOStartURL) != 0:
default:
return false
}
return true
}
func oneOrNone(bs ...bool) bool { func oneOrNone(bs ...bool) bool {
var count int var count int
@ -566,3 +695,35 @@ func (e CredentialRequiresARNError) OrigErr() error {
func (e CredentialRequiresARNError) Error() string { func (e CredentialRequiresARNError) Error() string {
return awserr.SprintError(e.Code(), e.Message(), "", nil) return awserr.SprintError(e.Code(), e.Message(), "", nil)
} }
// updateEndpointDiscoveryType will only update the dst with the value in the section, if
// a valid key and corresponding EndpointDiscoveryType is found.
func updateUseDualStackEndpoint(dst *endpoints.DualStackEndpointState, section ini.Section, key string) {
if !section.Has(key) {
return
}
if section.Bool(key) {
*dst = endpoints.DualStackEndpointStateEnabled
} else {
*dst = endpoints.DualStackEndpointStateDisabled
}
return
}
// updateEndpointDiscoveryType will only update the dst with the value in the section, if
// a valid key and corresponding EndpointDiscoveryType is found.
func updateUseFIPSEndpoint(dst *endpoints.FIPSEndpointState, section ini.Section, key string) {
if !section.Has(key) {
return
}
if section.Bool(key) {
*dst = endpoints.FIPSEndpointStateEnabled
} else {
*dst = endpoints.FIPSEndpointStateDisabled
}
return
}

View File

@ -34,23 +34,23 @@ func (m mapRule) IsValid(value string) bool {
return ok return ok
} }
// whitelist is a generic rule for whitelisting // allowList is a generic rule for allow listing
type whitelist struct { type allowList struct {
rule rule
} }
// IsValid for whitelist checks if the value is within the whitelist // IsValid for allow list checks if the value is within the allow list
func (w whitelist) IsValid(value string) bool { func (w allowList) IsValid(value string) bool {
return w.rule.IsValid(value) return w.rule.IsValid(value)
} }
// blacklist is a generic rule for blacklisting // excludeList is a generic rule for exclude listing
type blacklist struct { type excludeList struct {
rule rule
} }
// IsValid for whitelist checks if the value is within the whitelist // IsValid for exclude list checks if the value is within the exclude list
func (b blacklist) IsValid(value string) bool { func (b excludeList) IsValid(value string) bool {
return !b.rule.IsValid(value) return !b.rule.IsValid(value)
} }

View File

@ -1,3 +1,4 @@
//go:build !go1.7
// +build !go1.7 // +build !go1.7
package v4 package v4

View File

@ -1,3 +1,4 @@
//go:build go1.7
// +build go1.7 // +build go1.7
package v4 package v4

View File

@ -1,3 +1,4 @@
//go:build go1.5
// +build go1.5 // +build go1.5
package v4 package v4

View File

@ -90,7 +90,7 @@ const (
) )
var ignoredHeaders = rules{ var ignoredHeaders = rules{
blacklist{ excludeList{
mapRule{ mapRule{
authorizationHeader: struct{}{}, authorizationHeader: struct{}{},
"User-Agent": struct{}{}, "User-Agent": struct{}{},
@ -99,9 +99,9 @@ var ignoredHeaders = rules{
}, },
} }
// requiredSignedHeaders is a whitelist for build canonical headers. // requiredSignedHeaders is a allow list for build canonical headers.
var requiredSignedHeaders = rules{ var requiredSignedHeaders = rules{
whitelist{ allowList{
mapRule{ mapRule{
"Cache-Control": struct{}{}, "Cache-Control": struct{}{},
"Content-Disposition": struct{}{}, "Content-Disposition": struct{}{},
@ -145,12 +145,13 @@ var requiredSignedHeaders = rules{
}, },
}, },
patterns{"X-Amz-Meta-"}, patterns{"X-Amz-Meta-"},
patterns{"X-Amz-Object-Lock-"},
} }
// allowedHoisting is a whitelist for build query headers. The boolean value // allowedHoisting is a allow list for build query headers. The boolean value
// represents whether or not it is a pattern. // represents whether or not it is a pattern.
var allowedQueryHoisting = inclusiveRules{ var allowedQueryHoisting = inclusiveRules{
blacklist{requiredSignedHeaders}, excludeList{requiredSignedHeaders},
patterns{"X-Amz-"}, patterns{"X-Amz-"},
} }
@ -417,7 +418,7 @@ var SignRequestHandler = request.NamedHandler{
// request handler should only be used with the SDK's built in service client's // request handler should only be used with the SDK's built in service client's
// API operation requests. // API operation requests.
// //
// This function should not be used on its on its own, but in conjunction with // This function should not be used on its own, but in conjunction with
// an AWS service client's API operation call. To sign a standalone request // an AWS service client's API operation call. To sign a standalone request
// not created by a service client's API operation method use the "Sign" or // not created by a service client's API operation method use the "Sign" or
// "Presign" functions of the "Signer" type. // "Presign" functions of the "Signer" type.
@ -633,21 +634,25 @@ func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
ctx.Query.Set("X-Amz-SignedHeaders", ctx.signedHeaders) ctx.Query.Set("X-Amz-SignedHeaders", ctx.signedHeaders)
} }
headerValues := make([]string, len(headers)) headerItems := make([]string, len(headers))
for i, k := range headers { for i, k := range headers {
if k == "host" { if k == "host" {
if ctx.Request.Host != "" { if ctx.Request.Host != "" {
headerValues[i] = "host:" + ctx.Request.Host headerItems[i] = "host:" + ctx.Request.Host
} else { } else {
headerValues[i] = "host:" + ctx.Request.URL.Host headerItems[i] = "host:" + ctx.Request.URL.Host
} }
} else { } else {
headerValues[i] = k + ":" + headerValues := make([]string, len(ctx.SignedHeaderVals[k]))
strings.Join(ctx.SignedHeaderVals[k], ",") for i, v := range ctx.SignedHeaderVals[k] {
headerValues[i] = strings.TrimSpace(v)
}
headerItems[i] = k + ":" +
strings.Join(headerValues, ",")
} }
} }
stripExcessSpaces(headerValues) stripExcessSpaces(headerItems)
ctx.canonicalHeaders = strings.Join(headerValues, "\n") ctx.canonicalHeaders = strings.Join(headerItems, "\n")
} }
func (ctx *signingCtx) buildCanonicalString() { func (ctx *signingCtx) buildCanonicalString() {
@ -689,9 +694,12 @@ func (ctx *signingCtx) buildBodyDigest() error {
if hash == "" { if hash == "" {
includeSHA256Header := ctx.unsignedPayload || includeSHA256Header := ctx.unsignedPayload ||
ctx.ServiceName == "s3" || ctx.ServiceName == "s3" ||
ctx.ServiceName == "s3-object-lambda" ||
ctx.ServiceName == "glacier" ctx.ServiceName == "glacier"
s3Presign := ctx.isPresign && ctx.ServiceName == "s3" s3Presign := ctx.isPresign &&
(ctx.ServiceName == "s3" ||
ctx.ServiceName == "s3-object-lambda")
if ctx.unsignedPayload || s3Presign { if ctx.unsignedPayload || s3Presign {
hash = "UNSIGNED-PAYLOAD" hash = "UNSIGNED-PAYLOAD"

View File

@ -1,3 +1,4 @@
//go:build go1.8
// +build go1.8 // +build go1.8
package aws package aws

View File

@ -1,3 +1,4 @@
//go:build !go1.8
// +build !go1.8 // +build !go1.8
package aws package aws

View File

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go" const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK // SDKVersion is the version of this SDK
const SDKVersion = "1.36.12" const SDKVersion = "1.44.129"

View File

@ -1,3 +1,4 @@
//go:build !go1.7
// +build !go1.7 // +build !go1.7
package context package context

View File

@ -13,17 +13,30 @@
// } // }
// //
// Below is the BNF that describes this parser // Below is the BNF that describes this parser
// Grammar: // Grammar:
// stmt -> value stmt' // stmt -> section | stmt'
// stmt' -> epsilon | op stmt // stmt' -> epsilon | expr
// value -> number | string | boolean | quoted_string // expr -> value (stmt)* | equal_expr (stmt)*
// equal_expr -> value ( ':' | '=' ) equal_expr'
// equal_expr' -> number | string | quoted_string
// quoted_string -> " quoted_string'
// quoted_string' -> string quoted_string_end
// quoted_string_end -> "
// //
// section -> [ section' // section -> [ section'
// section' -> value section_close // section' -> section_value section_close
// section_close -> ] // section_value -> number | string_subset | boolean | quoted_string_subset
// quoted_string_subset -> " quoted_string_subset'
// quoted_string_subset' -> string_subset quoted_string_end
// quoted_string_subset -> "
// section_close -> ]
// //
// SkipState will skip (NL WS)+ // value -> number | string_subset | boolean
// string -> ? UTF-8 Code-Points except '\n' (U+000A) and '\r\n' (U+000D U+000A) ?
// string_subset -> ? Code-points excepted by <string> grammar except ':' (U+003A), '=' (U+003D), '[' (U+005B), and ']' (U+005D) ?
// //
// comment -> # comment' | ; comment' // SkipState will skip (NL WS)+
// comment' -> epsilon | value //
// comment -> # comment' | ; comment'
// comment' -> epsilon | value
package ini package ini

View File

@ -1,3 +1,4 @@
//go:build gofuzz
// +build gofuzz // +build gofuzz
package ini package ini

View File

@ -5,9 +5,12 @@ import (
"io" "io"
) )
// ParseState represents the current state of the parser.
type ParseState uint
// State enums for the parse table // State enums for the parse table
const ( const (
InvalidState = iota InvalidState ParseState = iota
// stmt -> value stmt' // stmt -> value stmt'
StatementState StatementState
// stmt' -> MarkComplete | op stmt // stmt' -> MarkComplete | op stmt
@ -36,8 +39,8 @@ const (
) )
// parseTable is a state machine to dictate the grammar above. // parseTable is a state machine to dictate the grammar above.
var parseTable = map[ASTKind]map[TokenType]int{ var parseTable = map[ASTKind]map[TokenType]ParseState{
ASTKindStart: map[TokenType]int{ ASTKindStart: {
TokenLit: StatementState, TokenLit: StatementState,
TokenSep: OpenScopeState, TokenSep: OpenScopeState,
TokenWS: SkipTokenState, TokenWS: SkipTokenState,
@ -45,7 +48,7 @@ var parseTable = map[ASTKind]map[TokenType]int{
TokenComment: CommentState, TokenComment: CommentState,
TokenNone: TerminalState, TokenNone: TerminalState,
}, },
ASTKindCommentStatement: map[TokenType]int{ ASTKindCommentStatement: {
TokenLit: StatementState, TokenLit: StatementState,
TokenSep: OpenScopeState, TokenSep: OpenScopeState,
TokenWS: SkipTokenState, TokenWS: SkipTokenState,
@ -53,7 +56,7 @@ var parseTable = map[ASTKind]map[TokenType]int{
TokenComment: CommentState, TokenComment: CommentState,
TokenNone: MarkCompleteState, TokenNone: MarkCompleteState,
}, },
ASTKindExpr: map[TokenType]int{ ASTKindExpr: {
TokenOp: StatementPrimeState, TokenOp: StatementPrimeState,
TokenLit: ValueState, TokenLit: ValueState,
TokenSep: OpenScopeState, TokenSep: OpenScopeState,
@ -62,13 +65,15 @@ var parseTable = map[ASTKind]map[TokenType]int{
TokenComment: CommentState, TokenComment: CommentState,
TokenNone: MarkCompleteState, TokenNone: MarkCompleteState,
}, },
ASTKindEqualExpr: map[TokenType]int{ ASTKindEqualExpr: {
TokenLit: ValueState, TokenLit: ValueState,
TokenWS: SkipTokenState, TokenSep: ValueState,
TokenNL: SkipState, TokenOp: ValueState,
TokenNone: SkipState, TokenWS: SkipTokenState,
TokenNL: SkipState,
TokenNone: SkipState,
}, },
ASTKindStatement: map[TokenType]int{ ASTKindStatement: {
TokenLit: SectionState, TokenLit: SectionState,
TokenSep: CloseScopeState, TokenSep: CloseScopeState,
TokenWS: SkipTokenState, TokenWS: SkipTokenState,
@ -76,9 +81,9 @@ var parseTable = map[ASTKind]map[TokenType]int{
TokenComment: CommentState, TokenComment: CommentState,
TokenNone: MarkCompleteState, TokenNone: MarkCompleteState,
}, },
ASTKindExprStatement: map[TokenType]int{ ASTKindExprStatement: {
TokenLit: ValueState, TokenLit: ValueState,
TokenSep: OpenScopeState, TokenSep: ValueState,
TokenOp: ValueState, TokenOp: ValueState,
TokenWS: ValueState, TokenWS: ValueState,
TokenNL: MarkCompleteState, TokenNL: MarkCompleteState,
@ -86,14 +91,14 @@ var parseTable = map[ASTKind]map[TokenType]int{
TokenNone: TerminalState, TokenNone: TerminalState,
TokenComma: SkipState, TokenComma: SkipState,
}, },
ASTKindSectionStatement: map[TokenType]int{ ASTKindSectionStatement: {
TokenLit: SectionState, TokenLit: SectionState,
TokenOp: SectionState, TokenOp: SectionState,
TokenSep: CloseScopeState, TokenSep: CloseScopeState,
TokenWS: SectionState, TokenWS: SectionState,
TokenNL: SkipTokenState, TokenNL: SkipTokenState,
}, },
ASTKindCompletedSectionStatement: map[TokenType]int{ ASTKindCompletedSectionStatement: {
TokenWS: SkipTokenState, TokenWS: SkipTokenState,
TokenNL: SkipTokenState, TokenNL: SkipTokenState,
TokenLit: StatementState, TokenLit: StatementState,
@ -101,7 +106,7 @@ var parseTable = map[ASTKind]map[TokenType]int{
TokenComment: CommentState, TokenComment: CommentState,
TokenNone: MarkCompleteState, TokenNone: MarkCompleteState,
}, },
ASTKindSkipStatement: map[TokenType]int{ ASTKindSkipStatement: {
TokenLit: StatementState, TokenLit: StatementState,
TokenSep: OpenScopeState, TokenSep: OpenScopeState,
TokenWS: SkipTokenState, TokenWS: SkipTokenState,
@ -205,18 +210,6 @@ loop:
case ValueState: case ValueState:
// ValueState requires the previous state to either be an equal expression // ValueState requires the previous state to either be an equal expression
// or an expression statement. // or an expression statement.
//
// This grammar occurs when the RHS is a number, word, or quoted string.
// equal_expr -> lit op equal_expr'
// equal_expr' -> number | string | quoted_string
// quoted_string -> " quoted_string'
// quoted_string' -> string quoted_string_end
// quoted_string_end -> "
//
// otherwise
// expr_stmt -> equal_expr (expr_stmt')*
// expr_stmt' -> ws S | op S | MarkComplete
// S -> equal_expr' expr_stmt'
switch k.Kind { switch k.Kind {
case ASTKindEqualExpr: case ASTKindEqualExpr:
// assigning a value to some key // assigning a value to some key
@ -243,7 +236,7 @@ loop:
} }
children[len(children)-1] = rhs children[len(children)-1] = rhs
k.SetChildren(children) root.SetChildren(children)
stack.Push(k) stack.Push(k)
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"unicode"
) )
var ( var (
@ -18,7 +19,7 @@ var literalValues = [][]rune{
func isBoolValue(b []rune) bool { func isBoolValue(b []rune) bool {
for _, lv := range literalValues { for _, lv := range literalValues {
if isLitValue(lv, b) { if isCaselessLitValue(lv, b) {
return true return true
} }
} }
@ -39,6 +40,21 @@ func isLitValue(want, have []rune) bool {
return true return true
} }
// isCaselessLitValue is a caseless value comparison, assumes want is already lower-cased for efficiency.
func isCaselessLitValue(want, have []rune) bool {
if len(have) < len(want) {
return false
}
for i := 0; i < len(want); i++ {
if want[i] != unicode.ToLower(have[i]) {
return false
}
}
return true
}
// isNumberValue will return whether not the leading characters in // isNumberValue will return whether not the leading characters in
// a byte slice is a number. A number is delimited by whitespace or // a byte slice is a number. A number is delimited by whitespace or
// the newline token. // the newline token.
@ -177,7 +193,7 @@ func newValue(t ValueType, base int, raw []rune) (Value, error) {
case QuotedStringType: case QuotedStringType:
v.str = string(raw[1 : len(raw)-1]) v.str = string(raw[1 : len(raw)-1])
case BoolType: case BoolType:
v.boolean = runeCompare(v.raw, runesTrue) v.boolean = isCaselessLitValue(runesTrue, v.raw)
} }
// issue 2253 // issue 2253

View File

@ -57,7 +57,7 @@ func getBoolValue(b []rune) (int, error) {
continue continue
} }
if isLitValue(lv, b) { if isCaselessLitValue(lv, b) {
n = len(lv) n = len(lv)
} }
} }

View File

@ -50,7 +50,10 @@ func (v *DefaultVisitor) VisitExpr(expr AST) error {
rhs := children[1] rhs := children[1]
if rhs.Root.Type() != TokenLit { // The right-hand value side the equality expression is allowed to contain '[', ']', ':', '=' in the values.
// If the token is not either a literal or one of the token types that identifies those four additional
// tokens then error.
if !(rhs.Root.Type() == TokenLit || rhs.Root.Type() == TokenOp || rhs.Root.Type() == TokenSep) {
return NewParseError("unexpected token type") return NewParseError("unexpected token type")
} }

View File

@ -6,6 +6,7 @@ go_library(
"accesspoint_arn.go", "accesspoint_arn.go",
"arn.go", "arn.go",
"outpost_arn.go", "outpost_arn.go",
"s3_object_lambda_arn.go",
], ],
importmap = "peridot.resf.org/vendor/github.com/aws/aws-sdk-go/internal/s3shared/arn", importmap = "peridot.resf.org/vendor/github.com/aws/aws-sdk-go/internal/s3shared/arn",
importpath = "github.com/aws/aws-sdk-go/internal/s3shared/arn", importpath = "github.com/aws/aws-sdk-go/internal/s3shared/arn",

View File

@ -7,6 +7,21 @@ import (
"github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/arn"
) )
var supportedServiceARN = []string{
"s3",
"s3-outposts",
"s3-object-lambda",
}
func isSupportedServiceARN(service string) bool {
for _, name := range supportedServiceARN {
if name == service {
return true
}
}
return false
}
// Resource provides the interfaces abstracting ARNs of specific resource // Resource provides the interfaces abstracting ARNs of specific resource
// types. // types.
type Resource interface { type Resource interface {
@ -29,9 +44,14 @@ func ParseResource(s string, resParser ResourceParser) (resARN Resource, err err
return nil, InvalidARNError{ARN: a, Reason: "partition not set"} return nil, InvalidARNError{ARN: a, Reason: "partition not set"}
} }
if a.Service != "s3" && a.Service != "s3-outposts" { if !isSupportedServiceARN(a.Service) {
return nil, InvalidARNError{ARN: a, Reason: "service is not supported"} return nil, InvalidARNError{ARN: a, Reason: "service is not supported"}
} }
if strings.HasPrefix(a.Region, "fips-") || strings.HasSuffix(a.Region, "-fips") {
return nil, InvalidARNError{ARN: a, Reason: "FIPS region not allowed in ARN"}
}
if len(a.Resource) == 0 { if len(a.Resource) == 0 {
return nil, InvalidARNError{ARN: a, Reason: "resource not set"} return nil, InvalidARNError{ARN: a, Reason: "resource not set"}
} }

Some files were not shown because too many files have changed in this diff Show More