Implement session endpoints
This commit is contained in:
parent
214de1a86f
commit
76a0a67c8d
11 changed files with 291 additions and 29 deletions
|
@ -31,6 +31,9 @@
|
|||
- "core::0133::request-id-field"
|
||||
# We don't require update mask support
|
||||
- "core::0134::request-mask-required"
|
||||
# We have custom methods that only return a single resource
|
||||
- "core::0136::response-message-name"
|
||||
- "core::0136::request-message-name"
|
||||
- included_paths:
|
||||
- "third_party/**/*.proto"
|
||||
- "vendor/**/*.proto"
|
||||
|
|
17
kojicompat/logout.go
Normal file
17
kojicompat/logout.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package kojicompat
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/rocky-linux/peridot/v2/xmlrpc"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
func (s *Server) Logout(ctx context.Context, _ *xmlrpc.Nothing) (*xmlrpc.Nothing, error) {
|
||||
_, err := s.client.DestroySession(ctx, &emptypb.Empty{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &xmlrpc.Nothing{}, nil
|
||||
}
|
|
@ -1,10 +1,11 @@
|
|||
CREATE TABLE users
|
||||
(
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
update_time TIMESTAMPTZ,
|
||||
username VARCHAR(255) NOT NULL,
|
||||
password_hash VARCHAR(255)
|
||||
username VARCHAR(255) UNIQUE NOT NULL,
|
||||
password_hash VARCHAR(255),
|
||||
type INT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE user_krb_principals
|
||||
|
@ -30,11 +31,12 @@ CREATE TABLE sessions
|
|||
create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
user_id BIGINT REFERENCES users (id) NOT NULL,
|
||||
expire_time TIMESTAMPTZ NOT NULL,
|
||||
last_renew_time TIMESTAMPTZ NOT NULL,
|
||||
last_renew_time TIMESTAMPTZ,
|
||||
renew_count INT NOT NULL DEFAULT 0,
|
||||
key VARCHAR(255) NOT NULL,
|
||||
auth_type INT NOT NULL,
|
||||
host_ip VARCHAR(255) NOT NULL,
|
||||
call_num INT NOT NULL,
|
||||
closed BOOLEAN NOT NULL DEFAULT FALSE
|
||||
closed BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
krb_principal VARCHAR(255)
|
||||
)
|
||||
|
|
|
@ -2,17 +2,122 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
peridotpb "github.com/rocky-linux/peridot/v2/peridot/v2"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
PikaTableName string `pika:"sessions"`
|
||||
PikaDefaultOrderBy string `pika:"-create_time"`
|
||||
|
||||
ID int64 `db:"id" pika:"omitempty"`
|
||||
CreateTime time.Time `db:"create_time" pika:"omitempty"`
|
||||
UserID int64 `db:"user_id"`
|
||||
ExpireTime time.Time `db:"expire_time"`
|
||||
LastRenewTime sql.Null[time.Time] `db:"last_renew_time" pika:"omitempty"`
|
||||
RenewCount int `db:"renew_count"`
|
||||
Key string `db:"key"`
|
||||
AuthType peridotpb.Session_AuthType `db:"auth_type"`
|
||||
HostIP string `db:"host_ip"`
|
||||
CallNum int `db:"call_num"`
|
||||
Closed bool `db:"closed"`
|
||||
KrbPrincipal sql.Null[string] `db:"krb_principal"`
|
||||
}
|
||||
|
||||
func (s Session) GetID() int64 {
|
||||
return int64(s.ID)
|
||||
}
|
||||
|
||||
func (s Session) ToPB(u *User) *peridotpb.Session {
|
||||
user := u.ToPB()
|
||||
return &peridotpb.Session{
|
||||
User: user,
|
||||
AuthType: s.AuthType,
|
||||
KerberosPrincipal: NullString(s.KrbPrincipal),
|
||||
}
|
||||
}
|
||||
|
||||
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func generateKey(userID int64) string {
|
||||
b := make([]byte, 19)
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return fmt.Sprintf("%d-%s", userID, string(b))
|
||||
}
|
||||
|
||||
func (s *Server) NewSession(ctx context.Context, req *peridotpb.NewSessionRequest) (*peridotpb.NewSessionResponse, error) {
|
||||
user := getUserFromContext(ctx)
|
||||
slog.Info("NewSession called by user", "user", user.Username)
|
||||
key := generateKey(user.GetID())
|
||||
|
||||
remotePeer, _ := peer.FromContext(ctx)
|
||||
ip, _, err := net.SplitHostPort(remotePeer.Addr.String())
|
||||
if err != nil {
|
||||
slog.Error("Failed to split host and port", "error", err)
|
||||
return nil, status.Error(codes.Internal, "Failed to split host and port")
|
||||
}
|
||||
|
||||
var authType peridotpb.Session_AuthType
|
||||
if _, ok := ctx.Value(ContextKerberosPrincipalKey{}).(*UserKerberosPrincipal); ok {
|
||||
authType = peridotpb.Session_AUTH_TYPE_GSSAPI
|
||||
} else {
|
||||
authType = peridotpb.Session_AUTH_TYPE_UNSPECIFIED
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
UserID: user.GetID(),
|
||||
Key: key,
|
||||
HostIP: ip,
|
||||
ExpireTime: time.Now().Add(time.Hour),
|
||||
CallNum: 1,
|
||||
AuthType: authType,
|
||||
}
|
||||
err = Q[Session](s.db).Createx(ctx, session)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create session", "error", err)
|
||||
return nil, status.Error(codes.Internal, "Failed to create session")
|
||||
}
|
||||
|
||||
return &peridotpb.NewSessionResponse{
|
||||
SessionId: 1,
|
||||
SessionKey: "session-key",
|
||||
SessionId: int32(session.ID),
|
||||
SessionKey: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) CurrentSession(ctx context.Context, req *emptypb.Empty) (*peridotpb.Session, error) {
|
||||
session, _ := ctx.Value(ContextSessionKey{}).(*Session)
|
||||
if session == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "No session")
|
||||
}
|
||||
|
||||
user := getUserFromContext(ctx)
|
||||
|
||||
return session.ToPB(user), nil
|
||||
}
|
||||
|
||||
func (s *Server) DestroySession(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
session, _ := ctx.Value(ContextSessionKey{}).(*Session)
|
||||
if session == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "No session")
|
||||
}
|
||||
|
||||
session.Closed = true
|
||||
err := Q[Session](s.db).Ux(ctx, session)
|
||||
if err != nil {
|
||||
slog.Error("Failed to close session", "error", err)
|
||||
return nil, status.Error(codes.Internal, "Failed to close session")
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
|
|
@ -173,7 +173,10 @@ func DefaultServeMuxOptions(s ...GRPCServer) []runtime.ServeMuxOption {
|
|||
runtime.WithIncomingHeaderMatcher(func(s string) (string, bool) {
|
||||
switch strings.ToLower(s) {
|
||||
case "authorization",
|
||||
"cookie":
|
||||
"cookie",
|
||||
"peridot-session-id",
|
||||
"peridot-session-key",
|
||||
"peridot-session-call-num":
|
||||
return s, true
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"log/slog"
|
||||
"strings"
|
||||
|
||||
peridotpb "github.com/rocky-linux/peridot/v2/peridot/v2"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
@ -19,6 +20,8 @@ import (
|
|||
"gopkg.in/jcmturner/gokrb5.v7/types"
|
||||
)
|
||||
|
||||
type ContextKerberosPrincipalKey struct{}
|
||||
|
||||
// From https://github.com/jcmturner/gokrb5/blob/855dbc707a37a21467aef6c0245fcf3328dc39ed/spnego/http.go#L196
|
||||
const (
|
||||
// spnegoNegTokenRespKRBAcceptCompleted - The response on successful authentication always has this header. Capturing as const so we don't have marshaling and encoding overhead.
|
||||
|
@ -112,8 +115,9 @@ func (s *Server) kerberosInterceptor(next InterceptorFunc) InterceptorFunc {
|
|||
if user == nil {
|
||||
user = &User{
|
||||
Username: id.UserName(),
|
||||
Type: peridotpb.User_TYPE_USER,
|
||||
}
|
||||
err = Q[User](s.db).Create(ctx, user)
|
||||
err = Q[User](s.db).Createx(ctx, user)
|
||||
if err != nil {
|
||||
slog.Error("Error creating user in DB", "error", err)
|
||||
return nil, status.Error(codes.Internal, "Error creating user in DB")
|
||||
|
@ -132,7 +136,7 @@ func (s *Server) kerberosInterceptor(next InterceptorFunc) InterceptorFunc {
|
|||
Principal: principal,
|
||||
UserID: user.ID,
|
||||
}
|
||||
err = Q[UserKerberosPrincipal](s.db).Create(ctx, krbPrincipal)
|
||||
err = Q[UserKerberosPrincipal](s.db).Createx(ctx, krbPrincipal)
|
||||
if err != nil {
|
||||
slog.Error("Error creating kerberos principal in DB", "error", err)
|
||||
return nil, status.Error(codes.Internal, "Error creating kerberos principal in DB")
|
||||
|
@ -140,6 +144,7 @@ func (s *Server) kerberosInterceptor(next InterceptorFunc) InterceptorFunc {
|
|||
}
|
||||
|
||||
ctx = context.WithValue(ctx, ContextUserKey{}, user)
|
||||
ctx = context.WithValue(ctx, ContextKerberosPrincipalKey{}, krbPrincipal)
|
||||
}
|
||||
|
||||
return next(ctx, req, usi, handler)
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
|
||||
orderedmap "github.com/wk8/go-ordered-map/v2"
|
||||
"go.ciq.dev/pika"
|
||||
)
|
||||
|
||||
|
@ -16,6 +15,7 @@ type Pika[T any] interface {
|
|||
Fx(keyval ...any) Pika[T]
|
||||
Dx(ctx context.Context, x any) error
|
||||
Transactionx(ctx context.Context) (Pika[T], error)
|
||||
Createx(ctx context.Context, x *T) error
|
||||
}
|
||||
|
||||
type innerDB[T any] struct {
|
||||
|
@ -27,15 +27,6 @@ type nameInterface interface {
|
|||
GetID() int64
|
||||
}
|
||||
|
||||
func NewDBArgs(keyval ...any) *orderedmap.OrderedMap[string, any] {
|
||||
args := pika.NewArgs()
|
||||
for i := 0; i < len(keyval); i += 2 {
|
||||
args.Set(keyval[i].(string), keyval[i+1])
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
func Q[T any](db DB) Pika[T] {
|
||||
return &innerDB[T]{pika.PSQLQuery[T](pika.NewPostgreSQLFromDB(db.GetDB())), db}
|
||||
}
|
||||
|
@ -109,8 +100,7 @@ func (inner *innerDB[T]) Ux(ctx context.Context, x any) error {
|
|||
return qs.Update(ctx, y)
|
||||
}
|
||||
|
||||
func (inner *innerDB[T]) Createx(x *T) error {
|
||||
ctx := context.TODO()
|
||||
func (inner *innerDB[T]) Createx(ctx context.Context, x *T) error {
|
||||
ts := pika.NewPostgreSQLFromDB(inner.db.GetDB())
|
||||
err := ts.Begin(ctx)
|
||||
if err != nil {
|
||||
|
@ -153,3 +143,16 @@ func (inner *innerDB[T]) Createx(x *T) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disable the default Create, Update, Delete methods
|
||||
func (inner *innerDB[T]) Create(ctx context.Context, value *T, options ...pika.CreateOption) error {
|
||||
return errors.New("Create method is disabled")
|
||||
}
|
||||
|
||||
func (inner *innerDB[T]) Update(ctx context.Context, value *T) error {
|
||||
return errors.New("Update method is disabled")
|
||||
}
|
||||
|
||||
func (inner *innerDB[T]) Delete(ctx context.Context) error {
|
||||
return errors.New("Delete method is disabled")
|
||||
}
|
||||
|
|
23
server/protobuf.go
Normal file
23
server/protobuf.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
func NullTimestampPb(t sql.Null[time.Time]) *timestamppb.Timestamp {
|
||||
if t.Valid {
|
||||
return timestamppb.New(t.V)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NullString(s sql.Null[string]) *wrapperspb.StringValue {
|
||||
if s.Valid {
|
||||
return wrapperspb.String(s.V)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -70,6 +70,7 @@ func (s *Server) interceptor(ctx context.Context, req interface{}, usi *grpc.Una
|
|||
n := EndInterceptor
|
||||
n = s.authEnforcerInterceptor(n)
|
||||
n = s.kerberosInterceptor(n)
|
||||
n = s.sessionAuthInterceptor(n)
|
||||
|
||||
return n(ctx, req, usi, handler)
|
||||
}
|
||||
|
|
78
server/session_auth.go
Normal file
78
server/session_auth.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type ContextSessionKey struct{}
|
||||
|
||||
func (s *Server) sessionAuthInterceptor(next InterceptorFunc) InterceptorFunc {
|
||||
return func(ctx context.Context, req any, usi *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
// If already authenticated, skip
|
||||
if getUserFromContext(ctx) != nil {
|
||||
return next(ctx, req, usi, handler)
|
||||
}
|
||||
|
||||
md, _ := metadata.FromIncomingContext(ctx)
|
||||
sessionID, _ := md["peridot-session-id"]
|
||||
sessionKey, _ := md["peridot-session-key"]
|
||||
callNum, _ := md["peridot-session-call-num"]
|
||||
if len(sessionID) != 1 || len(sessionKey) != 1 || len(callNum) != 1 {
|
||||
return next(ctx, req, usi, handler)
|
||||
}
|
||||
|
||||
// Check if the session is valid
|
||||
session, err := Q[Session](s.db).Fx("id", sessionID[0], "key", sessionKey[0]).GetOrNil(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "Error checking session")
|
||||
}
|
||||
if session == nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "Negotiate")
|
||||
}
|
||||
|
||||
// Verify not closed
|
||||
if session.Closed {
|
||||
return nil, status.Error(codes.Unauthenticated, "Negotiate")
|
||||
}
|
||||
|
||||
// Verify expiry
|
||||
if session.ExpireTime.Before(time.Now()) {
|
||||
return nil, status.Error(codes.Unauthenticated, "Negotiate")
|
||||
}
|
||||
|
||||
// Verify call number
|
||||
callNumInt, err := strconv.Atoi(callNum[0])
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "Invalid call number")
|
||||
}
|
||||
if session.CallNum != callNumInt {
|
||||
return nil, status.Error(codes.Internal, "Invalid call number")
|
||||
}
|
||||
|
||||
// Increase call number
|
||||
session.CallNum++
|
||||
err = Q[Session](s.db).Ux(ctx, session)
|
||||
if err != nil {
|
||||
slog.Error("Error updating session", "error", err)
|
||||
return nil, status.Error(codes.Internal, "Error updating session")
|
||||
}
|
||||
|
||||
// Set user in context
|
||||
user, err := Q[User](s.db).Fx("id", session.UserID).GetOrNil(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "Error getting user")
|
||||
}
|
||||
ctx = context.WithValue(ctx, ContextUserKey{}, user)
|
||||
ctx = context.WithValue(ctx, ContextSessionKey{}, session)
|
||||
|
||||
return next(ctx, req, usi, handler)
|
||||
}
|
||||
}
|
|
@ -3,24 +3,46 @@ package server
|
|||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
peridotpb "github.com/rocky-linux/peridot/v2/peridot/v2"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
PikaTableName string `pika:"users"`
|
||||
PikaDefaultOrderBy string `pika:"-create_time"`
|
||||
|
||||
ID int `db:"id"`
|
||||
CreateTime time.Time `db:"create_time"`
|
||||
UpdateTime sql.Null[time.Time] `db:"update_time"`
|
||||
ID int `db:"id" pika:"omitempty"`
|
||||
CreateTime time.Time `db:"create_time" pika:"omitempty"`
|
||||
UpdateTime sql.Null[time.Time] `db:"update_time" pika:"omitempty"`
|
||||
Username string `db:"username"`
|
||||
PasswordHash sql.Null[string] `db:"password_hash"`
|
||||
PasswordHash sql.Null[string] `db:"password_hash" pika:"omitempty"`
|
||||
Type peridotpb.User_Type `db:"type"`
|
||||
}
|
||||
|
||||
func (u User) GetID() int64 {
|
||||
return int64(u.ID)
|
||||
}
|
||||
|
||||
func (u User) ToPB() *peridotpb.User {
|
||||
return &peridotpb.User{
|
||||
Id: int64(u.ID),
|
||||
CreateTime: timestamppb.New(u.CreateTime),
|
||||
UpdateTime: NullTimestampPb(u.UpdateTime),
|
||||
Username: u.Username,
|
||||
Type: u.Type,
|
||||
}
|
||||
}
|
||||
|
||||
type UserKerberosPrincipal struct {
|
||||
PikaTableName string `pika:"user_krb_principals"`
|
||||
|
||||
ID int `db:"id"`
|
||||
CreateTime time.Time `db:"create_time"`
|
||||
ID int `db:"id" pika:"omitempty"`
|
||||
CreateTime time.Time `db:"create_time" pika:"omitempty"`
|
||||
UserID int `db:"user_id"`
|
||||
Principal string `db:"principal"`
|
||||
}
|
||||
|
||||
func (ukp UserKerberosPrincipal) GetID() int64 {
|
||||
return int64(ukp.ID)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue