feat(auth): paseto & sessions

This commit is contained in:
Elisiei Yehorov 2026-04-19 15:07:08 +02:00
parent 870c2357b7
commit 611619f180
Signed by: elisiei
GPG key ID: BA1D158DCE3DF089
14 changed files with 391 additions and 79 deletions

BIN
dbots Executable file

Binary file not shown.

8
go.mod
View file

@ -3,6 +3,7 @@ module codeberg.org/nextgo/dbots
go 1.25.8 go 1.25.8
require ( require (
aidanwoods.dev/go-paseto v1.6.0
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3 codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3
codeberg.org/ungo/gonsole v0.1.0 codeberg.org/ungo/gonsole v0.1.0
github.com/go-chi/chi/v5 v5.2.5 github.com/go-chi/chi/v5 v5.2.5
@ -12,10 +13,13 @@ require (
) )
require ( require (
aidanwoods.dev/go-result v0.3.1 // indirect
github.com/ajg/form v1.5.1 // indirect github.com/ajg/form v1.5.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/sync v0.17.0 // indirect golang.org/x/crypto v0.50.0 // indirect
golang.org/x/text v0.29.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/text v0.36.0 // indirect
) )

18
go.sum
View file

@ -1,5 +1,7 @@
codeberg.org/ungo/env v0.0.0-20260315114019-c4fbd9390cb3 h1:Xn8IiW5uYGajGqYPXU0kS8zXxqRs5E/MTfYjm0O1KrI= aidanwoods.dev/go-paseto v1.6.0 h1:JA/PFk5lVsB/PakQGqnfmik/1tIHjE6F0UoPPoAO/nU=
codeberg.org/ungo/env v0.0.0-20260315114019-c4fbd9390cb3/go.mod h1:pXfrNASG7JyxL30Zof3b1vbpd1dsHePTh3zGfPFgJKs= aidanwoods.dev/go-paseto v1.6.0/go.mod h1:LdqkL0Z2mLL0kBWzmHVR1cGFniX+zyOweQmbNKYrDxQ=
aidanwoods.dev/go-result v0.3.1 h1:ee98hpohYUVYbI+pa6gUHTyoRerIudgjky/IPSowDXQ=
aidanwoods.dev/go-result v0.3.1/go.mod h1:GKnFg8p/BKulVD3wsfULiPhpPmrTWyiTIbz8EWuUqSk=
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3 h1:k0NM+1XP3ebvfTvZfiHcyEZc0Drci5oxjZjE7L/xDdE= codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3 h1:k0NM+1XP3ebvfTvZfiHcyEZc0Drci5oxjZjE7L/xDdE=
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3/go.mod h1:pXfrNASG7JyxL30Zof3b1vbpd1dsHePTh3zGfPFgJKs= codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3/go.mod h1:pXfrNASG7JyxL30Zof3b1vbpd1dsHePTh3zGfPFgJKs=
codeberg.org/ungo/gonsole v0.1.0 h1:QE/qpSyovejIXzIh29tzmrwgDWfaKUqNTCMZPJEDfvY= codeberg.org/ungo/gonsole v0.1.0 h1:QE/qpSyovejIXzIh29tzmrwgDWfaKUqNTCMZPJEDfvY=
@ -30,10 +32,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View file

@ -9,6 +9,7 @@ type Config struct {
Database DatabaseConfig `env:"DATABASE_"` Database DatabaseConfig `env:"DATABASE_"`
Server ServerConfig `env:"SERVER_"` Server ServerConfig `env:"SERVER_"`
Discord DiscordConfig `env:"DISCORD_"` Discord DiscordConfig `env:"DISCORD_"`
Auth AuthConfig `env:"AUTH_"`
} }
type DatabaseConfig struct { type DatabaseConfig struct {
@ -27,6 +28,10 @@ type DiscordConfig struct {
RedirectURI string `env:"REDIRECT_URI,default=http://localhost:8080/auth/callback"` RedirectURI string `env:"REDIRECT_URI,default=http://localhost:8080/auth/callback"`
} }
type AuthConfig struct {
PasetoKey string `env:"PASETO_KEY,required"`
}
func LoadConfig() Config { func LoadConfig() Config {
var cfg Config var cfg Config
if err := env.Load(&cfg); err != nil { if err := env.Load(&cfg); err != nil {

View file

@ -31,6 +31,14 @@ type BotCoOwner struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
} }
type Session struct {
ID string `json:"id"`
UserID string `json:"user_id"`
CreatedAt *time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
Revoked bool `json:"revoked"`
}
type User struct { type User struct {
ID string `json:"id"` ID string `json:"id"`
Username string `json:"username"` Username string `json:"username"`

View file

@ -14,12 +14,14 @@ type Querier interface {
CountBotsByUsername(ctx context.Context, arg CountBotsByUsernameParams) (int64, error) CountBotsByUsername(ctx context.Context, arg CountBotsByUsernameParams) (int64, error)
CountVotesByBot(ctx context.Context, botID string) (int64, error) CountVotesByBot(ctx context.Context, botID string) (int64, error)
CreateBot(ctx context.Context, arg CreateBotParams) (*Bot, error) CreateBot(ctx context.Context, arg CreateBotParams) (*Bot, error)
CreateSession(ctx context.Context, arg CreateSessionParams) (*Session, error)
CreateUser(ctx context.Context, arg CreateUserParams) (*User, error) CreateUser(ctx context.Context, arg CreateUserParams) (*User, error)
CreateVote(ctx context.Context, arg CreateVoteParams) (*Vote, error) CreateVote(ctx context.Context, arg CreateVoteParams) (*Vote, error)
DeleteBot(ctx context.Context, id string) error DeleteBot(ctx context.Context, id string) error
DeleteUser(ctx context.Context, id string) error DeleteUser(ctx context.Context, id string) error
GetBot(ctx context.Context, id string) (*Bot, error) GetBot(ctx context.Context, id string) (*Bot, error)
GetBotCoOwner(ctx context.Context, arg GetBotCoOwnerParams) (*BotCoOwner, error) GetBotCoOwner(ctx context.Context, arg GetBotCoOwnerParams) (*BotCoOwner, error)
GetSession(ctx context.Context, id string) (*Session, error)
GetUser(ctx context.Context, id string) (*User, error) GetUser(ctx context.Context, id string) (*User, error)
GetUserByUsername(ctx context.Context, username string) (*User, error) GetUserByUsername(ctx context.Context, username string) (*User, error)
GetVote(ctx context.Context, arg GetVoteParams) (*Vote, error) GetVote(ctx context.Context, arg GetVoteParams) (*Vote, error)
@ -33,6 +35,8 @@ type Querier interface {
ListVotesByUser(ctx context.Context, userID string) ([]*Vote, error) ListVotesByUser(ctx context.Context, userID string) ([]*Vote, error)
RemoveAllCoOwnersByBot(ctx context.Context, botID string) error RemoveAllCoOwnersByBot(ctx context.Context, botID string) error
RemoveBotCoOwner(ctx context.Context, arg RemoveBotCoOwnerParams) error RemoveBotCoOwner(ctx context.Context, arg RemoveBotCoOwnerParams) error
RevokeAllUserSessions(ctx context.Context, userID string) error
RevokeSession(ctx context.Context, id string) error
SearchBotsByUsername(ctx context.Context, arg SearchBotsByUsernameParams) ([]*Bot, error) SearchBotsByUsername(ctx context.Context, arg SearchBotsByUsernameParams) ([]*Bot, error)
UpdateBot(ctx context.Context, arg UpdateBotParams) (*Bot, error) UpdateBot(ctx context.Context, arg UpdateBotParams) (*Bot, error)
UpdateBotStatus(ctx context.Context, arg UpdateBotStatusParams) (*Bot, error) UpdateBotStatus(ctx context.Context, arg UpdateBotStatusParams) (*Bot, error)

View file

@ -0,0 +1,74 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: sessions.sql
package db
import (
"context"
"time"
)
const createSession = `-- name: CreateSession :one
INSERT INTO sessions (id, user_id, expires_at)
VALUES ($1, $2, $3)
RETURNING id, user_id, created_at, expires_at, revoked
`
type CreateSessionParams struct {
ID string `json:"id"`
UserID string `json:"user_id"`
ExpiresAt time.Time `json:"expires_at"`
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (*Session, error) {
row := q.db.QueryRow(ctx, createSession, arg.ID, arg.UserID, arg.ExpiresAt)
var i Session
err := row.Scan(
&i.ID,
&i.UserID,
&i.CreatedAt,
&i.ExpiresAt,
&i.Revoked,
)
return &i, err
}
const getSession = `-- name: GetSession :one
SELECT id, user_id, created_at, expires_at, revoked FROM sessions
WHERE id = $1 AND revoked = false AND expires_at > now()
`
func (q *Queries) GetSession(ctx context.Context, id string) (*Session, error) {
row := q.db.QueryRow(ctx, getSession, id)
var i Session
err := row.Scan(
&i.ID,
&i.UserID,
&i.CreatedAt,
&i.ExpiresAt,
&i.Revoked,
)
return &i, err
}
const revokeAllUserSessions = `-- name: RevokeAllUserSessions :exec
UPDATE sessions SET revoked = true
WHERE user_id = $1
`
func (q *Queries) RevokeAllUserSessions(ctx context.Context, userID string) error {
_, err := q.db.Exec(ctx, revokeAllUserSessions, userID)
return err
}
const revokeSession = `-- name: RevokeSession :exec
UPDATE sessions SET revoked = true
WHERE id = $1
`
func (q *Queries) RevokeSession(ctx context.Context, id string) error {
_, err := q.db.Exec(ctx, revokeSession, id)
return err
}

View file

@ -0,0 +1,17 @@
-- +goose Up
-- +goose StatementBegin
create table sessions (
id text primary key, -- jti (random uuid)
user_id text not null references users (id) on delete cascade,
created_at timestamp with time zone default now(),
expires_at timestamp with time zone not null,
revoked boolean not null default false
);
create index sessions_user_id_idx on sessions (user_id);
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
drop table if exists sessions;
-- +goose StatementEnd

View file

@ -0,0 +1,16 @@
-- name: CreateSession :one
INSERT INTO sessions (id, user_id, expires_at)
VALUES ($1, $2, $3)
RETURNING *;
-- name: GetSession :one
SELECT * FROM sessions
WHERE id = $1 AND revoked = false AND expires_at > now();
-- name: RevokeSession :exec
UPDATE sessions SET revoked = true
WHERE id = $1;
-- name: RevokeAllUserSessions :exec
UPDATE sessions SET revoked = true
WHERE user_id = $1;

View file

@ -6,31 +6,63 @@ import (
"codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/db"
"codeberg.org/nextgo/dbots/internal/errorutil" "codeberg.org/nextgo/dbots/internal/errorutil"
"codeberg.org/nextgo/dbots/internal/token"
"github.com/go-chi/render" "github.com/go-chi/render"
) )
type contextKey string type contextKey string
const userKey contextKey = "user" // UserContextKey is exported so service routers can read the user from context.
var UserContextKey contextKey = "user"
// AuthMiddleware reads the PASETO session cookie, verifies it, checks it
// hasn't been revoked in the DB, then sets the *db.User on the context.
// Does NOT block unauthenticated requests — use AuthGuardMiddleware for that.
func AuthMiddleware(q *db.Queries, pasetoKeyHex string) func(http.Handler) http.Handler {
key, err := token.KeyFromHex(pasetoKeyHex)
if err != nil {
panic("middleware: invalid PASETO key: " + err.Error())
}
// AuthMiddleware is a middleware to set the user as context value.
// this middleware does not prevents the user from accessing the route
// if not authorized.
func AuthMiddleware(q *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ctx := context.WithValue(r.Context(), userKey, user) // mocked c, err := r.Cookie("session")
if err != nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return
}
claims, err := token.VerifyToken(key, c.Value)
if err != nil {
// Expired or tampered — clear the cookie and continue as anonymous.
http.SetCookie(w, &http.Cookie{Name: "session", MaxAge: -1, Path: "/"})
next.ServeHTTP(w, r)
return
}
// Check the session hasn't been revoked server-side.
if _, err := q.GetSession(r.Context(), claims.JTI); err != nil {
http.SetCookie(w, &http.Cookie{Name: "session", MaxAge: -1, Path: "/"})
next.ServeHTTP(w, r)
return
}
user, err := q.GetUser(r.Context(), claims.UserID)
if err != nil {
next.ServeHTTP(w, r)
return
}
ctx := context.WithValue(r.Context(), UserContextKey, user)
next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
} }
// AuthGuardMiddleware is a middleware that prevents the user // AuthGuardMiddleware blocks requests where no authenticated user was set.
// from accessing the route if they are NOT authorized.
func AuthGuardMiddleware(next http.Handler) http.Handler { func AuthGuardMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := r.Context().Value(userKey).(*db.User) if _, ok := r.Context().Value(UserContextKey).(*db.User); !ok {
if !ok {
render.Render(w, r, errorutil.ErrUnauthorized) render.Render(w, r, errorutil.ErrUnauthorized)
return return
} }
@ -38,10 +70,8 @@ func AuthGuardMiddleware(next http.Handler) http.Handler {
}) })
} }
// GetUser returns the authenticated user from context.
// Only safe to call inside a route guarded by AuthGuardMiddleware.
func GetUser(ctx context.Context) *db.User { func GetUser(ctx context.Context) *db.User {
user, ok := ctx.Value(userKey).(*db.User) return ctx.Value(UserContextKey).(*db.User)
if !ok {
return nil
}
return user
} }

View file

@ -32,7 +32,7 @@ func NewServer(queries *db.Queries, config config.Config) *Server {
router.Use(middleware.Recoverer) router.Use(middleware.Recoverer)
router.Use(middleware.RequestID) router.Use(middleware.RequestID)
router.Use(middleware.RealIP) router.Use(middleware.RealIP)
router.Use(customMiddlewares.AuthMiddleware(queries)) // todo: use this middleware only when necessary router.Use(customMiddlewares.AuthMiddleware(queries, config.Auth.PasetoKey)) // todo: use this middleware only when necessary
// i am using this globally cus it uses mocked data lol // i am using this globally cus it uses mocked data lol
return &Server{ return &Server{
@ -49,7 +49,7 @@ func (s *Server) Register() {
s.config.Discord.RedirectURI, s.config.Discord.RedirectURI,
) )
authRouter := auth.NewRouter(s.queries, discordClient) authRouter := auth.NewRouter(s.queries, discordClient, s.config.Auth.PasetoKey)
botRouter := bot.NewRouter(s.queries, discordClient) botRouter := bot.NewRouter(s.queries, discordClient)
adminRouter := admin.NewRouter(s.queries) adminRouter := admin.NewRouter(s.queries)

67
internal/token/token.go Normal file
View file

@ -0,0 +1,67 @@
package token
import (
"encoding/hex"
"errors"
"time"
"aidanwoods.dev/go-paseto"
)
const TokenDuration = 7 * 24 * time.Hour
const (
claimUserID = "uid"
)
// KeyFromHex loads a V4 symmetric key from a 64-char hex string (AUTH_PASETO_KEY).
func KeyFromHex(h string) (paseto.V4SymmetricKey, error) {
b, err := hex.DecodeString(h)
if err != nil {
return paseto.V4SymmetricKey{}, err
}
return paseto.V4SymmetricKeyFromBytes(b)
}
// IssueToken creates a PASETO v4 local (encrypted) token carrying the
// user ID and a jti that maps to a row in the sessions table.
func IssueToken(key paseto.V4SymmetricKey, userID, jti string) (string, error) {
tok := paseto.NewToken()
tok.SetIssuedAt(time.Now())
tok.SetNotBefore(time.Now())
tok.SetExpiration(time.Now().Add(TokenDuration))
tok.SetJti(jti)
if err := tok.Set(claimUserID, userID); err != nil {
return "", err
}
return tok.V4Encrypt(key, nil), nil
}
// Claims holds the verified payload extracted from a token.
type Claims struct {
UserID string
JTI string
}
// VerifyToken decrypts and validates a PASETO v4 local token.
func VerifyToken(key paseto.V4SymmetricKey, raw string) (Claims, error) {
parser := paseto.NewParser()
parser.AddRule(paseto.NotExpired())
parser.AddRule(paseto.ValidAt(time.Now()))
tok, err := parser.ParseV4Local(key, raw, nil)
if err != nil {
return Claims{}, err
}
var userID string
if err := tok.Get(claimUserID, &userID); err != nil {
return Claims{}, errors.New("missing uid claim")
}
jti, err := tok.GetJti()
if err != nil {
return Claims{}, errors.New("missing jti claim")
}
return Claims{UserID: userID, JTI: jti}, nil
}

View file

@ -3,16 +3,12 @@ package auth
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/hex"
"errors"
"codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/db"
"codeberg.org/nextgo/dbots/internal/discord" "codeberg.org/nextgo/dbots/internal/discord"
"github.com/jackc/pgx/v5"
) )
// todo: api keysssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss
// or sessions????????????
type Service struct { type Service struct {
q *db.Queries q *db.Queries
client *discord.Client client *discord.Client
@ -22,44 +18,42 @@ func NewService(q *db.Queries, client *discord.Client) *Service {
return &Service{q: q, client: client} return &Service{q: q, client: client}
} }
// GenerateState produces a random OAuth state parameter. // Callback exchanges the OAuth code for a Discord access token,
func GenerateState() (string, error) { // fetches the Discord user, and upserts them in the database.
b := make([]byte, 16) // It returns the db.User and the raw Discord access token
if _, err := rand.Read(b); err != nil { // (needed so the caller can store it in the session if desired).
return "", err func (s *Service) Callback(ctx context.Context, code string) (*db.User, *discord.TokenResponse, error) {
} tok, err := s.client.ExchangeCode(ctx, code)
return base64.URLEncoding.EncodeToString(b), nil
}
// Callback handles the OAuth callback: exchanges the code, fetches the user,
// and upserts them in the database. Returns the local db.User.
func (s *Service) Callback(ctx context.Context, code string) (*db.User, error) {
token, err := s.client.ExchangeCode(ctx, code)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
dUser, err := s.client.GetCurrentUser(ctx, token.AccessToken) dUser, err := s.client.GetCurrentUser(ctx, tok.AccessToken)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
user, err := s.q.UpdateUser(ctx, db.UpdateUserParams{ user, err := s.q.GetUser(ctx, dUser.ID)
ID: dUser.ID,
Username: &dUser.Username,
})
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) { // First login — create the user.
user, err = s.q.CreateUser(ctx, db.CreateUserParams{ user, err = s.q.CreateUser(ctx, db.CreateUserParams{
ID: dUser.ID, ID: dUser.ID,
Username: dUser.Username, Username: dUser.Username,
}) })
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
} }
return nil, err
}
return user, nil return user, tok, nil
}
// GenerateState returns a cryptographically random hex string for
// the OAuth2 state parameter.
func GenerateState() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
} }

View file

@ -2,23 +2,32 @@ package auth
import ( import (
"net/http" "net/http"
"time"
"codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/db"
"codeberg.org/nextgo/dbots/internal/discord" "codeberg.org/nextgo/dbots/internal/discord"
"codeberg.org/nextgo/dbots/internal/errorutil" "codeberg.org/nextgo/dbots/internal/errorutil"
"codeberg.org/nextgo/dbots/internal/middleware"
"codeberg.org/nextgo/dbots/internal/token"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/render" "github.com/go-chi/render"
) )
const cookieName = "session"
type Router struct { type Router struct {
auth *Service auth *Service
router chi.Router router chi.Router
pasetoKey string // hex-encoded AUTH_PASETO_KEY
queries *db.Queries
} }
func NewRouter(q *db.Queries, client *discord.Client) *Router { func NewRouter(q *db.Queries, client *discord.Client, pasetoKey string) *Router {
return &Router{ return &Router{
auth: NewService(q, client), auth: NewService(q, client),
router: chi.NewRouter(), router: chi.NewRouter(),
pasetoKey: pasetoKey,
queries: q,
} }
} }
@ -26,12 +35,17 @@ func (r *Router) Routes() http.Handler {
r.router.Get("/login", r.login) r.router.Get("/login", r.login)
r.router.Get("/callback", r.callback) r.router.Get("/callback", r.callback)
r.router.Post("/logout", r.logout) r.router.Post("/logout", r.logout)
r.router.Get("/me", r.me) r.router.With(middleware.AuthGuardMiddleware).Get("/me", r.me)
return r.router return r.router
} }
func (r *Router) me(w http.ResponseWriter, req *http.Request) { func (r *Router) me(w http.ResponseWriter, req *http.Request) {
user := middleware.GetUser(req.Context())
if user == nil {
render.Render(w, req, errorutil.ErrUnauthorized)
return
}
render.JSON(w, req, user)
} }
func (r *Router) login(w http.ResponseWriter, req *http.Request) { func (r *Router) login(w http.ResponseWriter, req *http.Request) {
@ -40,29 +54,102 @@ func (r *Router) login(w http.ResponseWriter, req *http.Request) {
render.Render(w, req, errorutil.ErrInternal(err)) render.Render(w, req, errorutil.ErrInternal(err))
return return
} }
// todo: store state in a short-lived cookie or session before redirecting
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: state,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
MaxAge: 300,
Path: "/",
})
http.Redirect(w, req, r.auth.client.AuthURL(state), http.StatusFound) http.Redirect(w, req, r.auth.client.AuthURL(state), http.StatusFound)
} }
func (r *Router) callback(w http.ResponseWriter, req *http.Request) { func (r *Router) callback(w http.ResponseWriter, req *http.Request) {
// todo: validate state matches what was stored stateCookie, err := req.Cookie("oauth_state")
if err != nil || stateCookie.Value != req.URL.Query().Get("state") {
render.Render(w, req, errorutil.ErrUnauthorized)
return
}
http.SetCookie(w, &http.Cookie{Name: "oauth_state", MaxAge: -1, Path: "/"})
code := req.URL.Query().Get("code") code := req.URL.Query().Get("code")
if code == "" { if code == "" {
render.Render(w, req, errorutil.ErrInvalidRequest(nil)) render.Render(w, req, errorutil.ErrInvalidRequest(nil))
return return
} }
user, err := r.auth.Callback(req.Context(), code) user, _, err := r.auth.Callback(req.Context(), code)
if err != nil { if err != nil {
render.Render(w, req, errorutil.ErrInternal(err)) render.Render(w, req, errorutil.ErrInternal(err))
return return
} }
// todo: create a session, set a cookie, then redirect to "/" // Generate a session ID (jti) and persist it to the DB for server-side revocation.
render.JSON(w, req, user) jti, err := GenerateState() // reuses the same crypto/rand helper
if err != nil {
render.Render(w, req, errorutil.ErrInternal(err))
return
}
_, err = r.queries.CreateSession(req.Context(), db.CreateSessionParams{
ID: jti,
UserID: user.ID,
ExpiresAt: time.Now().Add(token.TokenDuration),
})
if err != nil {
render.Render(w, req, errorutil.ErrInternal(err))
return
}
key, err := token.KeyFromHex(r.pasetoKey)
if err != nil {
render.Render(w, req, errorutil.ErrInternal(err))
return
}
raw, err := token.IssueToken(key, user.ID, jti)
if err != nil {
render.Render(w, req, errorutil.ErrInternal(err))
return
}
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: raw,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
MaxAge: int(token.TokenDuration.Seconds()),
Path: "/",
// Secure: true, // enable in production (HTTPS)
})
http.Redirect(w, req, "/", http.StatusFound)
} }
// POST /auth/logout — revoke the session server-side and clear the cookie.
func (r *Router) logout(w http.ResponseWriter, req *http.Request) { func (r *Router) logout(w http.ResponseWriter, req *http.Request) {
// todo: delete session c, err := req.Cookie(cookieName)
if err != nil {
// Already logged out.
render.NoContent(w, req)
return
}
key, err := token.KeyFromHex(r.pasetoKey)
if err == nil {
if claims, err := token.VerifyToken(key, c.Value); err == nil {
// Best-effort: ignore DB errors, the cookie will be cleared anyway.
_ = r.queries.RevokeSession(req.Context(), claims.JTI)
}
}
http.SetCookie(w, &http.Cookie{
Name: cookieName,
MaxAge: -1,
Path: "/",
})
render.NoContent(w, req) render.NoContent(w, req)
} }