diff --git a/dbots b/dbots new file mode 100755 index 0000000..879cdf3 Binary files /dev/null and b/dbots differ diff --git a/go.mod b/go.mod index caedcfc..7420eca 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module codeberg.org/nextgo/dbots go 1.25.8 require ( + aidanwoods.dev/go-paseto v1.6.0 codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3 codeberg.org/ungo/gonsole v0.1.0 github.com/go-chi/chi/v5 v5.2.5 @@ -12,10 +13,13 @@ require ( ) require ( + aidanwoods.dev/go-result v0.3.1 // indirect github.com/ajg/form v1.5.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - golang.org/x/sync v0.17.0 // indirect - golang.org/x/text v0.29.0 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect ) diff --git a/go.sum b/go.sum index df647ea..1ab0484 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ -codeberg.org/ungo/env v0.0.0-20260315114019-c4fbd9390cb3 h1:Xn8IiW5uYGajGqYPXU0kS8zXxqRs5E/MTfYjm0O1KrI= -codeberg.org/ungo/env v0.0.0-20260315114019-c4fbd9390cb3/go.mod h1:pXfrNASG7JyxL30Zof3b1vbpd1dsHePTh3zGfPFgJKs= +aidanwoods.dev/go-paseto v1.6.0 h1:JA/PFk5lVsB/PakQGqnfmik/1tIHjE6F0UoPPoAO/nU= +aidanwoods.dev/go-paseto v1.6.0/go.mod h1:LdqkL0Z2mLL0kBWzmHVR1cGFniX+zyOweQmbNKYrDxQ= +aidanwoods.dev/go-result v0.3.1 h1:ee98hpohYUVYbI+pa6gUHTyoRerIudgjky/IPSowDXQ= +aidanwoods.dev/go-result v0.3.1/go.mod h1:GKnFg8p/BKulVD3wsfULiPhpPmrTWyiTIbz8EWuUqSk= codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3 h1:k0NM+1XP3ebvfTvZfiHcyEZc0Drci5oxjZjE7L/xDdE= codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3/go.mod h1:pXfrNASG7JyxL30Zof3b1vbpd1dsHePTh3zGfPFgJKs= codeberg.org/ungo/gonsole v0.1.0 h1:QE/qpSyovejIXzIh29tzmrwgDWfaKUqNTCMZPJEDfvY= @@ -30,10 +32,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= -golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/config/config.go b/internal/config/config.go index b3a617e..b08820c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,6 +9,7 @@ type Config struct { Database DatabaseConfig `env:"DATABASE_"` Server ServerConfig `env:"SERVER_"` Discord DiscordConfig `env:"DISCORD_"` + Auth AuthConfig `env:"AUTH_"` } type DatabaseConfig struct { @@ -27,6 +28,10 @@ type DiscordConfig struct { RedirectURI string `env:"REDIRECT_URI,default=http://localhost:8080/auth/callback"` } +type AuthConfig struct { + PasetoKey string `env:"PASETO_KEY,required"` +} + func LoadConfig() Config { var cfg Config if err := env.Load(&cfg); err != nil { diff --git a/internal/db/models.go b/internal/db/models.go index cde1bf1..d1cd4ed 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -31,6 +31,14 @@ type BotCoOwner struct { UserID string `json:"user_id"` } +type Session struct { + ID string `json:"id"` + UserID string `json:"user_id"` + CreatedAt *time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + Revoked bool `json:"revoked"` +} + type User struct { ID string `json:"id"` Username string `json:"username"` diff --git a/internal/db/querier.go b/internal/db/querier.go index 30fbce6..a0fa311 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -14,12 +14,14 @@ type Querier interface { CountBotsByUsername(ctx context.Context, arg CountBotsByUsernameParams) (int64, error) CountVotesByBot(ctx context.Context, botID string) (int64, error) CreateBot(ctx context.Context, arg CreateBotParams) (*Bot, error) + CreateSession(ctx context.Context, arg CreateSessionParams) (*Session, error) CreateUser(ctx context.Context, arg CreateUserParams) (*User, error) CreateVote(ctx context.Context, arg CreateVoteParams) (*Vote, error) DeleteBot(ctx context.Context, id string) error DeleteUser(ctx context.Context, id string) error GetBot(ctx context.Context, id string) (*Bot, error) GetBotCoOwner(ctx context.Context, arg GetBotCoOwnerParams) (*BotCoOwner, error) + GetSession(ctx context.Context, id string) (*Session, error) GetUser(ctx context.Context, id string) (*User, error) GetUserByUsername(ctx context.Context, username string) (*User, error) GetVote(ctx context.Context, arg GetVoteParams) (*Vote, error) @@ -33,6 +35,8 @@ type Querier interface { ListVotesByUser(ctx context.Context, userID string) ([]*Vote, error) RemoveAllCoOwnersByBot(ctx context.Context, botID string) error RemoveBotCoOwner(ctx context.Context, arg RemoveBotCoOwnerParams) error + RevokeAllUserSessions(ctx context.Context, userID string) error + RevokeSession(ctx context.Context, id string) error SearchBotsByUsername(ctx context.Context, arg SearchBotsByUsernameParams) ([]*Bot, error) UpdateBot(ctx context.Context, arg UpdateBotParams) (*Bot, error) UpdateBotStatus(ctx context.Context, arg UpdateBotStatusParams) (*Bot, error) diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go new file mode 100644 index 0000000..b89de8f --- /dev/null +++ b/internal/db/sessions.sql.go @@ -0,0 +1,74 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: sessions.sql + +package db + +import ( + "context" + "time" +) + +const createSession = `-- name: CreateSession :one +INSERT INTO sessions (id, user_id, expires_at) +VALUES ($1, $2, $3) +RETURNING id, user_id, created_at, expires_at, revoked +` + +type CreateSessionParams struct { + ID string `json:"id"` + UserID string `json:"user_id"` + ExpiresAt time.Time `json:"expires_at"` +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (*Session, error) { + row := q.db.QueryRow(ctx, createSession, arg.ID, arg.UserID, arg.ExpiresAt) + var i Session + err := row.Scan( + &i.ID, + &i.UserID, + &i.CreatedAt, + &i.ExpiresAt, + &i.Revoked, + ) + return &i, err +} + +const getSession = `-- name: GetSession :one +SELECT id, user_id, created_at, expires_at, revoked FROM sessions +WHERE id = $1 AND revoked = false AND expires_at > now() +` + +func (q *Queries) GetSession(ctx context.Context, id string) (*Session, error) { + row := q.db.QueryRow(ctx, getSession, id) + var i Session + err := row.Scan( + &i.ID, + &i.UserID, + &i.CreatedAt, + &i.ExpiresAt, + &i.Revoked, + ) + return &i, err +} + +const revokeAllUserSessions = `-- name: RevokeAllUserSessions :exec +UPDATE sessions SET revoked = true +WHERE user_id = $1 +` + +func (q *Queries) RevokeAllUserSessions(ctx context.Context, userID string) error { + _, err := q.db.Exec(ctx, revokeAllUserSessions, userID) + return err +} + +const revokeSession = `-- name: RevokeSession :exec +UPDATE sessions SET revoked = true +WHERE id = $1 +` + +func (q *Queries) RevokeSession(ctx context.Context, id string) error { + _, err := q.db.Exec(ctx, revokeSession, id) + return err +} diff --git a/internal/db/sql/migrations/20260419123958_sessions.sql b/internal/db/sql/migrations/20260419123958_sessions.sql new file mode 100644 index 0000000..60aec61 --- /dev/null +++ b/internal/db/sql/migrations/20260419123958_sessions.sql @@ -0,0 +1,17 @@ +-- +goose Up +-- +goose StatementBegin +create table sessions ( + id text primary key, -- jti (random uuid) + user_id text not null references users (id) on delete cascade, + created_at timestamp with time zone default now(), + expires_at timestamp with time zone not null, + revoked boolean not null default false +); + +create index sessions_user_id_idx on sessions (user_id); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +drop table if exists sessions; +-- +goose StatementEnd diff --git a/internal/db/sql/queries/sessions.sql b/internal/db/sql/queries/sessions.sql new file mode 100644 index 0000000..6731750 --- /dev/null +++ b/internal/db/sql/queries/sessions.sql @@ -0,0 +1,16 @@ +-- name: CreateSession :one +INSERT INTO sessions (id, user_id, expires_at) +VALUES ($1, $2, $3) +RETURNING *; + +-- name: GetSession :one +SELECT * FROM sessions +WHERE id = $1 AND revoked = false AND expires_at > now(); + +-- name: RevokeSession :exec +UPDATE sessions SET revoked = true +WHERE id = $1; + +-- name: RevokeAllUserSessions :exec +UPDATE sessions SET revoked = true +WHERE user_id = $1; diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 5f2adab..7ba177b 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -6,31 +6,63 @@ import ( "codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/errorutil" + "codeberg.org/nextgo/dbots/internal/token" "github.com/go-chi/render" ) type contextKey string -const userKey contextKey = "user" +// UserContextKey is exported so service routers can read the user from context. +var UserContextKey contextKey = "user" + +// AuthMiddleware reads the PASETO session cookie, verifies it, checks it +// hasn't been revoked in the DB, then sets the *db.User on the context. +// Does NOT block unauthenticated requests — use AuthGuardMiddleware for that. +func AuthMiddleware(q *db.Queries, pasetoKeyHex string) func(http.Handler) http.Handler { + key, err := token.KeyFromHex(pasetoKeyHex) + if err != nil { + panic("middleware: invalid PASETO key: " + err.Error()) + } -// AuthMiddleware is a middleware to set the user as context value. -// this middleware does not prevents the user from accessing the route -// if not authorized. -func AuthMiddleware(q *db.Queries) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // ctx := context.WithValue(r.Context(), userKey, user) // mocked - next.ServeHTTP(w, r) + c, err := r.Cookie("session") + if err != nil { + next.ServeHTTP(w, r) + return + } + + claims, err := token.VerifyToken(key, c.Value) + if err != nil { + // Expired or tampered — clear the cookie and continue as anonymous. + http.SetCookie(w, &http.Cookie{Name: "session", MaxAge: -1, Path: "/"}) + next.ServeHTTP(w, r) + return + } + + // Check the session hasn't been revoked server-side. + if _, err := q.GetSession(r.Context(), claims.JTI); err != nil { + http.SetCookie(w, &http.Cookie{Name: "session", MaxAge: -1, Path: "/"}) + next.ServeHTTP(w, r) + return + } + + user, err := q.GetUser(r.Context(), claims.UserID) + if err != nil { + next.ServeHTTP(w, r) + return + } + + ctx := context.WithValue(r.Context(), UserContextKey, user) + next.ServeHTTP(w, r.WithContext(ctx)) }) } } -// AuthGuardMiddleware is a middleware that prevents the user -// from accessing the route if they are NOT authorized. +// AuthGuardMiddleware blocks requests where no authenticated user was set. func AuthGuardMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, ok := r.Context().Value(userKey).(*db.User) - if !ok { + if _, ok := r.Context().Value(UserContextKey).(*db.User); !ok { render.Render(w, r, errorutil.ErrUnauthorized) return } @@ -38,10 +70,8 @@ func AuthGuardMiddleware(next http.Handler) http.Handler { }) } +// GetUser returns the authenticated user from context. +// Only safe to call inside a route guarded by AuthGuardMiddleware. func GetUser(ctx context.Context) *db.User { - user, ok := ctx.Value(userKey).(*db.User) - if !ok { - return nil - } - return user + return ctx.Value(UserContextKey).(*db.User) } diff --git a/internal/server/server.go b/internal/server/server.go index 977f77b..18c6052 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -32,7 +32,7 @@ func NewServer(queries *db.Queries, config config.Config) *Server { router.Use(middleware.Recoverer) router.Use(middleware.RequestID) router.Use(middleware.RealIP) - router.Use(customMiddlewares.AuthMiddleware(queries)) // todo: use this middleware only when necessary + router.Use(customMiddlewares.AuthMiddleware(queries, config.Auth.PasetoKey)) // todo: use this middleware only when necessary // i am using this globally cus it uses mocked data lol return &Server{ @@ -49,7 +49,7 @@ func (s *Server) Register() { s.config.Discord.RedirectURI, ) - authRouter := auth.NewRouter(s.queries, discordClient) + authRouter := auth.NewRouter(s.queries, discordClient, s.config.Auth.PasetoKey) botRouter := bot.NewRouter(s.queries, discordClient) adminRouter := admin.NewRouter(s.queries) diff --git a/internal/token/token.go b/internal/token/token.go new file mode 100644 index 0000000..8552b22 --- /dev/null +++ b/internal/token/token.go @@ -0,0 +1,67 @@ +package token + +import ( + "encoding/hex" + "errors" + "time" + + "aidanwoods.dev/go-paseto" +) + +const TokenDuration = 7 * 24 * time.Hour + +const ( + claimUserID = "uid" +) + +// KeyFromHex loads a V4 symmetric key from a 64-char hex string (AUTH_PASETO_KEY). +func KeyFromHex(h string) (paseto.V4SymmetricKey, error) { + b, err := hex.DecodeString(h) + if err != nil { + return paseto.V4SymmetricKey{}, err + } + return paseto.V4SymmetricKeyFromBytes(b) +} + +// IssueToken creates a PASETO v4 local (encrypted) token carrying the +// user ID and a jti that maps to a row in the sessions table. +func IssueToken(key paseto.V4SymmetricKey, userID, jti string) (string, error) { + tok := paseto.NewToken() + tok.SetIssuedAt(time.Now()) + tok.SetNotBefore(time.Now()) + tok.SetExpiration(time.Now().Add(TokenDuration)) + tok.SetJti(jti) + if err := tok.Set(claimUserID, userID); err != nil { + return "", err + } + return tok.V4Encrypt(key, nil), nil +} + +// Claims holds the verified payload extracted from a token. +type Claims struct { + UserID string + JTI string +} + +// VerifyToken decrypts and validates a PASETO v4 local token. +func VerifyToken(key paseto.V4SymmetricKey, raw string) (Claims, error) { + parser := paseto.NewParser() + parser.AddRule(paseto.NotExpired()) + parser.AddRule(paseto.ValidAt(time.Now())) + + tok, err := parser.ParseV4Local(key, raw, nil) + if err != nil { + return Claims{}, err + } + + var userID string + if err := tok.Get(claimUserID, &userID); err != nil { + return Claims{}, errors.New("missing uid claim") + } + jti, err := tok.GetJti() + if err != nil { + return Claims{}, errors.New("missing jti claim") + } + + return Claims{UserID: userID, JTI: jti}, nil +} diff --git a/services/auth/auth.go b/services/auth/auth.go index 76f9a15..b0d79f5 100644 --- a/services/auth/auth.go +++ b/services/auth/auth.go @@ -3,16 +3,12 @@ package auth import ( "context" "crypto/rand" - "encoding/base64" - "errors" + "encoding/hex" "codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/discord" - "github.com/jackc/pgx/v5" ) -// todo: api keysssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss -// or sessions???????????? type Service struct { q *db.Queries client *discord.Client @@ -22,44 +18,42 @@ func NewService(q *db.Queries, client *discord.Client) *Service { return &Service{q: q, client: client} } -// GenerateState produces a random OAuth state parameter. +// Callback exchanges the OAuth code for a Discord access token, +// fetches the Discord user, and upserts them in the database. +// It returns the db.User and the raw Discord access token +// (needed so the caller can store it in the session if desired). +func (s *Service) Callback(ctx context.Context, code string) (*db.User, *discord.TokenResponse, error) { + tok, err := s.client.ExchangeCode(ctx, code) + if err != nil { + return nil, nil, err + } + + dUser, err := s.client.GetCurrentUser(ctx, tok.AccessToken) + if err != nil { + return nil, nil, err + } + + user, err := s.q.GetUser(ctx, dUser.ID) + if err != nil { + // First login — create the user. + user, err = s.q.CreateUser(ctx, db.CreateUserParams{ + ID: dUser.ID, + Username: dUser.Username, + }) + if err != nil { + return nil, nil, err + } + } + + return user, tok, nil +} + +// GenerateState returns a cryptographically random hex string for +// the OAuth2 state parameter. func GenerateState() (string, error) { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { return "", err } - return base64.URLEncoding.EncodeToString(b), nil -} - -// Callback handles the OAuth callback: exchanges the code, fetches the user, -// and upserts them in the database. Returns the local db.User. -func (s *Service) Callback(ctx context.Context, code string) (*db.User, error) { - token, err := s.client.ExchangeCode(ctx, code) - if err != nil { - return nil, err - } - - dUser, err := s.client.GetCurrentUser(ctx, token.AccessToken) - if err != nil { - return nil, err - } - - user, err := s.q.UpdateUser(ctx, db.UpdateUserParams{ - ID: dUser.ID, - Username: &dUser.Username, - }) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - user, err = s.q.CreateUser(ctx, db.CreateUserParams{ - ID: dUser.ID, - Username: dUser.Username, - }) - if err != nil { - return nil, err - } - } - return nil, err - } - - return user, nil + return hex.EncodeToString(b), nil } diff --git a/services/auth/router.go b/services/auth/router.go index 647509f..3afc38e 100644 --- a/services/auth/router.go +++ b/services/auth/router.go @@ -2,23 +2,32 @@ package auth import ( "net/http" + "time" "codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/discord" "codeberg.org/nextgo/dbots/internal/errorutil" + "codeberg.org/nextgo/dbots/internal/middleware" + "codeberg.org/nextgo/dbots/internal/token" "github.com/go-chi/chi/v5" "github.com/go-chi/render" ) +const cookieName = "session" + type Router struct { - auth *Service - router chi.Router + auth *Service + router chi.Router + pasetoKey string // hex-encoded AUTH_PASETO_KEY + queries *db.Queries } -func NewRouter(q *db.Queries, client *discord.Client) *Router { +func NewRouter(q *db.Queries, client *discord.Client, pasetoKey string) *Router { return &Router{ - auth: NewService(q, client), - router: chi.NewRouter(), + auth: NewService(q, client), + router: chi.NewRouter(), + pasetoKey: pasetoKey, + queries: q, } } @@ -26,12 +35,17 @@ func (r *Router) Routes() http.Handler { r.router.Get("/login", r.login) r.router.Get("/callback", r.callback) r.router.Post("/logout", r.logout) - r.router.Get("/me", r.me) + r.router.With(middleware.AuthGuardMiddleware).Get("/me", r.me) return r.router } func (r *Router) me(w http.ResponseWriter, req *http.Request) { - + user := middleware.GetUser(req.Context()) + if user == nil { + render.Render(w, req, errorutil.ErrUnauthorized) + return + } + render.JSON(w, req, user) } func (r *Router) login(w http.ResponseWriter, req *http.Request) { @@ -40,29 +54,102 @@ func (r *Router) login(w http.ResponseWriter, req *http.Request) { render.Render(w, req, errorutil.ErrInternal(err)) return } - // todo: store state in a short-lived cookie or session before redirecting + + http.SetCookie(w, &http.Cookie{ + Name: "oauth_state", + Value: state, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 300, + Path: "/", + }) + http.Redirect(w, req, r.auth.client.AuthURL(state), http.StatusFound) } func (r *Router) callback(w http.ResponseWriter, req *http.Request) { - // todo: validate state matches what was stored + stateCookie, err := req.Cookie("oauth_state") + if err != nil || stateCookie.Value != req.URL.Query().Get("state") { + render.Render(w, req, errorutil.ErrUnauthorized) + return + } + http.SetCookie(w, &http.Cookie{Name: "oauth_state", MaxAge: -1, Path: "/"}) + code := req.URL.Query().Get("code") if code == "" { render.Render(w, req, errorutil.ErrInvalidRequest(nil)) return } - user, err := r.auth.Callback(req.Context(), code) + user, _, err := r.auth.Callback(req.Context(), code) if err != nil { render.Render(w, req, errorutil.ErrInternal(err)) return } - // todo: create a session, set a cookie, then redirect to "/" - render.JSON(w, req, user) + // Generate a session ID (jti) and persist it to the DB for server-side revocation. + jti, err := GenerateState() // reuses the same crypto/rand helper + if err != nil { + render.Render(w, req, errorutil.ErrInternal(err)) + return + } + + _, err = r.queries.CreateSession(req.Context(), db.CreateSessionParams{ + ID: jti, + UserID: user.ID, + ExpiresAt: time.Now().Add(token.TokenDuration), + }) + if err != nil { + render.Render(w, req, errorutil.ErrInternal(err)) + return + } + + key, err := token.KeyFromHex(r.pasetoKey) + if err != nil { + render.Render(w, req, errorutil.ErrInternal(err)) + return + } + + raw, err := token.IssueToken(key, user.ID, jti) + if err != nil { + render.Render(w, req, errorutil.ErrInternal(err)) + return + } + + http.SetCookie(w, &http.Cookie{ + Name: cookieName, + Value: raw, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + MaxAge: int(token.TokenDuration.Seconds()), + Path: "/", + // Secure: true, // enable in production (HTTPS) + }) + + http.Redirect(w, req, "/", http.StatusFound) } +// POST /auth/logout — revoke the session server-side and clear the cookie. func (r *Router) logout(w http.ResponseWriter, req *http.Request) { - // todo: delete session + c, err := req.Cookie(cookieName) + if err != nil { + // Already logged out. + render.NoContent(w, req) + return + } + + key, err := token.KeyFromHex(r.pasetoKey) + if err == nil { + if claims, err := token.VerifyToken(key, c.Value); err == nil { + // Best-effort: ignore DB errors, the cookie will be cleared anyway. + _ = r.queries.RevokeSession(req.Context(), claims.JTI) + } + } + + http.SetCookie(w, &http.Cookie{ + Name: cookieName, + MaxAge: -1, + Path: "/", + }) render.NoContent(w, req) }