Implement session endpoints

This commit is contained in:
Mustafa Gezen 2024-12-25 13:40:58 +01:00
parent 214de1a86f
commit 76a0a67c8d
11 changed files with 291 additions and 29 deletions

View file

@ -31,6 +31,9 @@
- "core::0133::request-id-field"
# We don't require update mask support
- "core::0134::request-mask-required"
# We have custom methods that only return a single resource
- "core::0136::response-message-name"
- "core::0136::request-message-name"
- included_paths:
- "third_party/**/*.proto"
- "vendor/**/*.proto"

17
kojicompat/logout.go Normal file
View file

@ -0,0 +1,17 @@
package kojicompat
import (
"context"
"github.com/rocky-linux/peridot/v2/xmlrpc"
"google.golang.org/protobuf/types/known/emptypb"
)
func (s *Server) Logout(ctx context.Context, _ *xmlrpc.Nothing) (*xmlrpc.Nothing, error) {
_, err := s.client.DestroySession(ctx, &emptypb.Empty{})
if err != nil {
return nil, err
}
return &xmlrpc.Nothing{}, nil
}

View file

@ -1,10 +1,11 @@
CREATE TABLE users
(
id BIGSERIAL PRIMARY KEY,
create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMPTZ,
username VARCHAR(255) NOT NULL,
password_hash VARCHAR(255)
username VARCHAR(255) UNIQUE NOT NULL,
password_hash VARCHAR(255),
type INT NOT NULL
);
CREATE TABLE user_krb_principals
@ -30,11 +31,12 @@ CREATE TABLE sessions
create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id BIGINT REFERENCES users (id) NOT NULL,
expire_time TIMESTAMPTZ NOT NULL,
last_renew_time TIMESTAMPTZ NOT NULL,
last_renew_time TIMESTAMPTZ,
renew_count INT NOT NULL DEFAULT 0,
key VARCHAR(255) NOT NULL,
auth_type INT NOT NULL,
host_ip VARCHAR(255) NOT NULL,
call_num INT NOT NULL,
closed BOOLEAN NOT NULL DEFAULT FALSE
closed BOOLEAN NOT NULL DEFAULT FALSE,
krb_principal VARCHAR(255)
)

View file

@ -2,17 +2,122 @@ package server
import (
"context"
"database/sql"
"fmt"
"log/slog"
"math/rand"
"net"
"time"
peridotpb "github.com/rocky-linux/peridot/v2/peridot/v2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
)
type Session struct {
PikaTableName string `pika:"sessions"`
PikaDefaultOrderBy string `pika:"-create_time"`
ID int64 `db:"id" pika:"omitempty"`
CreateTime time.Time `db:"create_time" pika:"omitempty"`
UserID int64 `db:"user_id"`
ExpireTime time.Time `db:"expire_time"`
LastRenewTime sql.Null[time.Time] `db:"last_renew_time" pika:"omitempty"`
RenewCount int `db:"renew_count"`
Key string `db:"key"`
AuthType peridotpb.Session_AuthType `db:"auth_type"`
HostIP string `db:"host_ip"`
CallNum int `db:"call_num"`
Closed bool `db:"closed"`
KrbPrincipal sql.Null[string] `db:"krb_principal"`
}
func (s Session) GetID() int64 {
return int64(s.ID)
}
func (s Session) ToPB(u *User) *peridotpb.Session {
user := u.ToPB()
return &peridotpb.Session{
User: user,
AuthType: s.AuthType,
KerberosPrincipal: NullString(s.KrbPrincipal),
}
}
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func generateKey(userID int64) string {
b := make([]byte, 19)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return fmt.Sprintf("%d-%s", userID, string(b))
}
func (s *Server) NewSession(ctx context.Context, req *peridotpb.NewSessionRequest) (*peridotpb.NewSessionResponse, error) {
user := getUserFromContext(ctx)
slog.Info("NewSession called by user", "user", user.Username)
key := generateKey(user.GetID())
remotePeer, _ := peer.FromContext(ctx)
ip, _, err := net.SplitHostPort(remotePeer.Addr.String())
if err != nil {
slog.Error("Failed to split host and port", "error", err)
return nil, status.Error(codes.Internal, "Failed to split host and port")
}
var authType peridotpb.Session_AuthType
if _, ok := ctx.Value(ContextKerberosPrincipalKey{}).(*UserKerberosPrincipal); ok {
authType = peridotpb.Session_AUTH_TYPE_GSSAPI
} else {
authType = peridotpb.Session_AUTH_TYPE_UNSPECIFIED
}
session := &Session{
UserID: user.GetID(),
Key: key,
HostIP: ip,
ExpireTime: time.Now().Add(time.Hour),
CallNum: 1,
AuthType: authType,
}
err = Q[Session](s.db).Createx(ctx, session)
if err != nil {
slog.Error("Failed to create session", "error", err)
return nil, status.Error(codes.Internal, "Failed to create session")
}
return &peridotpb.NewSessionResponse{
SessionId: 1,
SessionKey: "session-key",
SessionId: int32(session.ID),
SessionKey: key,
}, nil
}
func (s *Server) CurrentSession(ctx context.Context, req *emptypb.Empty) (*peridotpb.Session, error) {
session, _ := ctx.Value(ContextSessionKey{}).(*Session)
if session == nil {
return nil, status.Error(codes.FailedPrecondition, "No session")
}
user := getUserFromContext(ctx)
return session.ToPB(user), nil
}
func (s *Server) DestroySession(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) {
session, _ := ctx.Value(ContextSessionKey{}).(*Session)
if session == nil {
return nil, status.Error(codes.FailedPrecondition, "No session")
}
session.Closed = true
err := Q[Session](s.db).Ux(ctx, session)
if err != nil {
slog.Error("Failed to close session", "error", err)
return nil, status.Error(codes.Internal, "Failed to close session")
}
return &emptypb.Empty{}, nil
}

View file

@ -173,7 +173,10 @@ func DefaultServeMuxOptions(s ...GRPCServer) []runtime.ServeMuxOption {
runtime.WithIncomingHeaderMatcher(func(s string) (string, bool) {
switch strings.ToLower(s) {
case "authorization",
"cookie":
"cookie",
"peridot-session-id",
"peridot-session-key",
"peridot-session-call-num":
return s, true
}

View file

@ -7,6 +7,7 @@ import (
"log/slog"
"strings"
peridotpb "github.com/rocky-linux/peridot/v2/peridot/v2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
@ -19,6 +20,8 @@ import (
"gopkg.in/jcmturner/gokrb5.v7/types"
)
type ContextKerberosPrincipalKey struct{}
// From https://github.com/jcmturner/gokrb5/blob/855dbc707a37a21467aef6c0245fcf3328dc39ed/spnego/http.go#L196
const (
// spnegoNegTokenRespKRBAcceptCompleted - The response on successful authentication always has this header. Capturing as const so we don't have marshaling and encoding overhead.
@ -112,8 +115,9 @@ func (s *Server) kerberosInterceptor(next InterceptorFunc) InterceptorFunc {
if user == nil {
user = &User{
Username: id.UserName(),
Type: peridotpb.User_TYPE_USER,
}
err = Q[User](s.db).Create(ctx, user)
err = Q[User](s.db).Createx(ctx, user)
if err != nil {
slog.Error("Error creating user in DB", "error", err)
return nil, status.Error(codes.Internal, "Error creating user in DB")
@ -132,7 +136,7 @@ func (s *Server) kerberosInterceptor(next InterceptorFunc) InterceptorFunc {
Principal: principal,
UserID: user.ID,
}
err = Q[UserKerberosPrincipal](s.db).Create(ctx, krbPrincipal)
err = Q[UserKerberosPrincipal](s.db).Createx(ctx, krbPrincipal)
if err != nil {
slog.Error("Error creating kerberos principal in DB", "error", err)
return nil, status.Error(codes.Internal, "Error creating kerberos principal in DB")
@ -140,6 +144,7 @@ func (s *Server) kerberosInterceptor(next InterceptorFunc) InterceptorFunc {
}
ctx = context.WithValue(ctx, ContextUserKey{}, user)
ctx = context.WithValue(ctx, ContextKerberosPrincipalKey{}, krbPrincipal)
}
return next(ctx, req, usi, handler)

View file

@ -5,7 +5,6 @@ import (
"errors"
"fmt"
orderedmap "github.com/wk8/go-ordered-map/v2"
"go.ciq.dev/pika"
)
@ -16,6 +15,7 @@ type Pika[T any] interface {
Fx(keyval ...any) Pika[T]
Dx(ctx context.Context, x any) error
Transactionx(ctx context.Context) (Pika[T], error)
Createx(ctx context.Context, x *T) error
}
type innerDB[T any] struct {
@ -27,15 +27,6 @@ type nameInterface interface {
GetID() int64
}
func NewDBArgs(keyval ...any) *orderedmap.OrderedMap[string, any] {
args := pika.NewArgs()
for i := 0; i < len(keyval); i += 2 {
args.Set(keyval[i].(string), keyval[i+1])
}
return args
}
func Q[T any](db DB) Pika[T] {
return &innerDB[T]{pika.PSQLQuery[T](pika.NewPostgreSQLFromDB(db.GetDB())), db}
}
@ -109,8 +100,7 @@ func (inner *innerDB[T]) Ux(ctx context.Context, x any) error {
return qs.Update(ctx, y)
}
func (inner *innerDB[T]) Createx(x *T) error {
ctx := context.TODO()
func (inner *innerDB[T]) Createx(ctx context.Context, x *T) error {
ts := pika.NewPostgreSQLFromDB(inner.db.GetDB())
err := ts.Begin(ctx)
if err != nil {
@ -153,3 +143,16 @@ func (inner *innerDB[T]) Createx(x *T) error {
return nil
}
// Disable the default Create, Update, Delete methods
func (inner *innerDB[T]) Create(ctx context.Context, value *T, options ...pika.CreateOption) error {
return errors.New("Create method is disabled")
}
func (inner *innerDB[T]) Update(ctx context.Context, value *T) error {
return errors.New("Update method is disabled")
}
func (inner *innerDB[T]) Delete(ctx context.Context) error {
return errors.New("Delete method is disabled")
}

23
server/protobuf.go Normal file
View file

@ -0,0 +1,23 @@
package server
import (
"database/sql"
"time"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
func NullTimestampPb(t sql.Null[time.Time]) *timestamppb.Timestamp {
if t.Valid {
return timestamppb.New(t.V)
}
return nil
}
func NullString(s sql.Null[string]) *wrapperspb.StringValue {
if s.Valid {
return wrapperspb.String(s.V)
}
return nil
}

View file

@ -70,6 +70,7 @@ func (s *Server) interceptor(ctx context.Context, req interface{}, usi *grpc.Una
n := EndInterceptor
n = s.authEnforcerInterceptor(n)
n = s.kerberosInterceptor(n)
n = s.sessionAuthInterceptor(n)
return n(ctx, req, usi, handler)
}

78
server/session_auth.go Normal file
View file

@ -0,0 +1,78 @@
package server
import (
"context"
"log/slog"
"strconv"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
type ContextSessionKey struct{}
func (s *Server) sessionAuthInterceptor(next InterceptorFunc) InterceptorFunc {
return func(ctx context.Context, req any, usi *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
// If already authenticated, skip
if getUserFromContext(ctx) != nil {
return next(ctx, req, usi, handler)
}
md, _ := metadata.FromIncomingContext(ctx)
sessionID, _ := md["peridot-session-id"]
sessionKey, _ := md["peridot-session-key"]
callNum, _ := md["peridot-session-call-num"]
if len(sessionID) != 1 || len(sessionKey) != 1 || len(callNum) != 1 {
return next(ctx, req, usi, handler)
}
// Check if the session is valid
session, err := Q[Session](s.db).Fx("id", sessionID[0], "key", sessionKey[0]).GetOrNil(ctx)
if err != nil {
return nil, status.Error(codes.Internal, "Error checking session")
}
if session == nil {
return nil, status.Error(codes.Unauthenticated, "Negotiate")
}
// Verify not closed
if session.Closed {
return nil, status.Error(codes.Unauthenticated, "Negotiate")
}
// Verify expiry
if session.ExpireTime.Before(time.Now()) {
return nil, status.Error(codes.Unauthenticated, "Negotiate")
}
// Verify call number
callNumInt, err := strconv.Atoi(callNum[0])
if err != nil {
return nil, status.Error(codes.Internal, "Invalid call number")
}
if session.CallNum != callNumInt {
return nil, status.Error(codes.Internal, "Invalid call number")
}
// Increase call number
session.CallNum++
err = Q[Session](s.db).Ux(ctx, session)
if err != nil {
slog.Error("Error updating session", "error", err)
return nil, status.Error(codes.Internal, "Error updating session")
}
// Set user in context
user, err := Q[User](s.db).Fx("id", session.UserID).GetOrNil(ctx)
if err != nil {
return nil, status.Error(codes.Internal, "Error getting user")
}
ctx = context.WithValue(ctx, ContextUserKey{}, user)
ctx = context.WithValue(ctx, ContextSessionKey{}, session)
return next(ctx, req, usi, handler)
}
}

View file

@ -3,24 +3,46 @@ package server
import (
"database/sql"
"time"
peridotpb "github.com/rocky-linux/peridot/v2/peridot/v2"
"google.golang.org/protobuf/types/known/timestamppb"
)
type User struct {
PikaTableName string `pika:"users"`
PikaDefaultOrderBy string `pika:"-create_time"`
ID int `db:"id"`
CreateTime time.Time `db:"create_time"`
UpdateTime sql.Null[time.Time] `db:"update_time"`
ID int `db:"id" pika:"omitempty"`
CreateTime time.Time `db:"create_time" pika:"omitempty"`
UpdateTime sql.Null[time.Time] `db:"update_time" pika:"omitempty"`
Username string `db:"username"`
PasswordHash sql.Null[string] `db:"password_hash"`
PasswordHash sql.Null[string] `db:"password_hash" pika:"omitempty"`
Type peridotpb.User_Type `db:"type"`
}
func (u User) GetID() int64 {
return int64(u.ID)
}
func (u User) ToPB() *peridotpb.User {
return &peridotpb.User{
Id: int64(u.ID),
CreateTime: timestamppb.New(u.CreateTime),
UpdateTime: NullTimestampPb(u.UpdateTime),
Username: u.Username,
Type: u.Type,
}
}
type UserKerberosPrincipal struct {
PikaTableName string `pika:"user_krb_principals"`
ID int `db:"id"`
CreateTime time.Time `db:"create_time"`
ID int `db:"id" pika:"omitempty"`
CreateTime time.Time `db:"create_time" pika:"omitempty"`
UserID int `db:"user_id"`
Principal string `db:"principal"`
}
func (ukp UserKerberosPrincipal) GetID() int64 {
return int64(ukp.ID)
}