diff --git a/base/go/fs.go b/base/go/fs.go index 22f11114..e750bb2b 100644 --- a/base/go/fs.go +++ b/base/go/fs.go @@ -1 +1,80 @@ package base + +import ( + "embed" + "github.com/pkg/errors" + "io" + "os" + "strings" +) + +// EmbedFSToOSFS re-creates the structure of the embed.FS in the OS filesystem. +// It is mostly useful for testing but other uses are possible. +// This function takes a prefix / directory and an embed.FS and creates the +// corresponding directory structure in the OS filesystem. +// Returns all created file paths. +func EmbedFSToOSFS(prefix string, fs embed.FS, fsDirectory string) ([]string, error) { + // Make sure the prefix exists as a directory + err := os.MkdirAll(prefix, 0755) + if err != nil { + return nil, errors.Wrap(err, "could not create directory") + } + + // Read the directory entries + entries, err := fs.ReadDir(fsDirectory) + if err != nil { + return nil, errors.Wrap(err, "could not read directory") + } + + var res []string + + // Iterate over the entries + for _, entry := range entries { + // Create the full path + full := prefix + "/" + entry.Name() + + // Check if the entry is a directory + if entry.IsDir() { + // If it is, recurse + res2, err := EmbedFSToOSFS(full, fs, fsDirectory+"/"+entry.Name()) + if err != nil { + return nil, errors.Wrap(err, "could not recurse") + } + res = append(res, res2...) + } else { + // If not, create the file + f, err := os.Create(full) + if err != nil { + return nil, errors.Wrap(err, "could not create file") + } + + // Open the file in the embed.FS + file, err := fs.Open(strings.TrimPrefix(fsDirectory+"/"+entry.Name(), "./")) + if err != nil { + return nil, errors.Wrap(err, "could not open file") + } + + // Copy the file + _, err = io.Copy(f, file) + if err != nil { + return nil, errors.Wrap(err, "could not copy file") + } + + // Close the OS file + err = f.Close() + if err != nil { + return nil, errors.Wrap(err, "could not close file") + } + + // Close the embed.FS file + err = file.Close() + if err != nil { + return nil, errors.Wrap(err, "could not close file") + } + + res = append(res, full) + } + } + + return res, nil +} diff --git a/scripts/govendor b/scripts/govendor index 6d6aafab..113b4ce7 100755 --- a/scripts/govendor +++ b/scripts/govendor @@ -1,7 +1,9 @@ #!/usr/bin/env bash -rm -rf vendor +set -euo pipefail + go mod tidy -e +rm -rf vendor go mod vendor -e find vendor -name "BUILD.bazel" -delete find vendor -name "BUILD" -delete diff --git a/tools/mothership/admin/rpc/BUILD b/tools/mothership/admin/rpc/BUILD index 38800ec4..634900c8 100644 --- a/tools/mothership/admin/rpc/BUILD +++ b/tools/mothership/admin/rpc/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "rpc", @@ -12,7 +12,6 @@ go_library( "//base/go", "//tools/mothership/db", "//tools/mothership/proto/admin/v1:pb", - "//vendor/github.com/google/uuid", "//vendor/go.ciq.dev/pika", "@go_googleapis//google/rpc:errdetails_go_proto", "@org_golang_google_grpc//:go_default_library", @@ -21,3 +20,26 @@ go_library( "@org_golang_google_protobuf//types/known/emptypb", ], ) + +go_test( + name = "rpc_test", + size = "small", + srcs = [ + "main_test.go", + "worker_test.go", + ], + embed = [":rpc"], + deps = [ + "//base/go", + "//tools/mothership/db", + "//tools/mothership/migrations", + "//tools/mothership/proto/admin/v1:pb", + "//vendor/github.com/stretchr/testify/require", + "//vendor/github.com/testcontainers/testcontainers-go", + "//vendor/github.com/testcontainers/testcontainers-go/modules/postgres", + "//vendor/github.com/testcontainers/testcontainers-go/wait", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//metadata", + "@org_golang_google_grpc//status", + ], +) diff --git a/tools/mothership/admin/rpc/main_test.go b/tools/mothership/admin/rpc/main_test.go index 3cea84f7..4e4f9664 100644 --- a/tools/mothership/admin/rpc/main_test.go +++ b/tools/mothership/admin/rpc/main_test.go @@ -1 +1,85 @@ package mothershipadmin_rpc + +import ( + "context" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + base "go.resf.org/peridot/base/go" + "go.resf.org/peridot/tools/mothership/migrations" + "google.golang.org/grpc/metadata" + "os" + "testing" + "time" +) + +var ( + s *Server + userInfo base.UserInfo +) + +func TestMain(m *testing.M) { + // Create temporary file + dir, err := os.MkdirTemp("", "test-db-*") + if err != nil { + panic(err) + } + defer os.RemoveAll(dir) + + scripts, err := base.EmbedFSToOSFS(dir, migrations.UpSQLs, ".") + if err != nil { + panic(err) + } + + ctx := context.Background() + pgContainer, err := postgres.RunContainer( + ctx, + testcontainers.WithImage("postgres:15.3-alpine"), + postgres.WithInitScripts(scripts...), + postgres.WithDatabase("mshiptest"), + postgres.WithUsername("postgres"), + postgres.WithPassword("postgres"), + testcontainers.WithWaitStrategy( + wait. + ForLog("database system is ready to accept connections"). + WithOccurrence(2).WithStartupTimeout(5*time.Second), + ), + ) + if err != nil { + panic(err) + } + defer pgContainer.Terminate(ctx) + + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + if err != nil { + panic(err) + } + + db, err := base.NewDB(connStr) + if err != nil { + panic(err) + } + + provider := base.NewTestOidcProvider(&userInfo) + + interceptorDetails := &base.OidcInterceptorDetails{ + Provider: provider, + Group: "", + } + s, err = NewServer(db, interceptorDetails) + if err != nil { + panic(err) + } + + os.Exit(m.Run()) +} + +func testContext() context.Context { + mdMap := map[string]string{} + if userInfo != nil { + mdMap["authorization"] = "bearer " + userInfo.Subject() + } + md := metadata.New(mdMap) + ctx := metadata.NewIncomingContext(context.Background(), md) + return context.WithValue(ctx, "user", userInfo) +} diff --git a/tools/mothership/admin/rpc/worker_test.go b/tools/mothership/admin/rpc/worker_test.go index 3cea84f7..b9c8d125 100644 --- a/tools/mothership/admin/rpc/worker_test.go +++ b/tools/mothership/admin/rpc/worker_test.go @@ -1 +1,40 @@ package mothershipadmin_rpc + +import ( + "github.com/stretchr/testify/require" + base "go.resf.org/peridot/base/go" + mshipadminpb "go.resf.org/peridot/tools/mothership/admin/pb" + mothership_db "go.resf.org/peridot/tools/mothership/db" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "testing" +) + +func TestGetWorker_Empty(t *testing.T) { + require.Nil(t, base.Q[mothership_db.Worker](s.db).Delete()) + worker, err := s.GetWorker(testContext(), &mshipadminpb.GetWorkerRequest{}) + require.NotNil(t, err) + require.Nil(t, worker) + expectedErr := status.Error(codes.NotFound, "worker not found") + require.Equal(t, expectedErr.Error(), err.Error()) +} + +func TestGetWorker_One(t *testing.T) { + require.Nil(t, base.Q[mothership_db.Worker](s.db).Delete()) + require.Nil(t, base.Q[mothership_db.Worker](s.db).Create(&mothership_db.Worker{ + Name: "test", + WorkerID: "test-id", + ApiSecret: "secret", + })) + defer func() { + require.Nil(t, base.Q[mothership_db.Worker](s.db).Delete()) + }() + + worker, err := s.GetWorker(testContext(), &mshipadminpb.GetWorkerRequest{ + Name: "test", + }) + require.Nil(t, err) + require.Equal(t, "test", worker.Name) + require.Equal(t, "test-id", worker.WorkerId) + require.Empty(t, worker.ApiSecret) +} diff --git a/tools/mothership/migrations/000001_init.up.sql b/tools/mothership/migrations/000001_init.up.sql index dee22d67..6f93cfd9 100644 --- a/tools/mothership/migrations/000001_init.up.sql +++ b/tools/mothership/migrations/000001_init.up.sql @@ -17,4 +17,4 @@ CREATE TABLE entries worker_id VARCHAR(255) REFERENCES workers (worker_id), batch_name VARCHAR(255), user_email TEXT -) +); diff --git a/tools/mothership/migrations/migrations.go b/tools/mothership/migrations/migrations.go index a6ea3eef..c7f0908b 100644 --- a/tools/mothership/migrations/migrations.go +++ b/tools/mothership/migrations/migrations.go @@ -1 +1,6 @@ package migrations + +import "embed" + +//go:embed *.up.sql +var UpSQLs embed.FS