feat(auth): paseto & sessions
This commit is contained in:
parent
870c2357b7
commit
611619f180
14 changed files with 391 additions and 79 deletions
BIN
dbots
Executable file
BIN
dbots
Executable file
Binary file not shown.
8
go.mod
8
go.mod
|
|
@ -3,6 +3,7 @@ module codeberg.org/nextgo/dbots
|
||||||
go 1.25.8
|
go 1.25.8
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
aidanwoods.dev/go-paseto v1.6.0
|
||||||
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3
|
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3
|
||||||
codeberg.org/ungo/gonsole v0.1.0
|
codeberg.org/ungo/gonsole v0.1.0
|
||||||
github.com/go-chi/chi/v5 v5.2.5
|
github.com/go-chi/chi/v5 v5.2.5
|
||||||
|
|
@ -12,10 +13,13 @@ require (
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
aidanwoods.dev/go-result v0.3.1 // indirect
|
||||||
github.com/ajg/form v1.5.1 // indirect
|
github.com/ajg/form v1.5.1 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
golang.org/x/sync v0.17.0 // indirect
|
golang.org/x/crypto v0.50.0 // indirect
|
||||||
golang.org/x/text v0.29.0 // indirect
|
golang.org/x/sync v0.20.0 // indirect
|
||||||
|
golang.org/x/sys v0.43.0 // indirect
|
||||||
|
golang.org/x/text v0.36.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
||||||
18
go.sum
18
go.sum
|
|
@ -1,5 +1,7 @@
|
||||||
codeberg.org/ungo/env v0.0.0-20260315114019-c4fbd9390cb3 h1:Xn8IiW5uYGajGqYPXU0kS8zXxqRs5E/MTfYjm0O1KrI=
|
aidanwoods.dev/go-paseto v1.6.0 h1:JA/PFk5lVsB/PakQGqnfmik/1tIHjE6F0UoPPoAO/nU=
|
||||||
codeberg.org/ungo/env v0.0.0-20260315114019-c4fbd9390cb3/go.mod h1:pXfrNASG7JyxL30Zof3b1vbpd1dsHePTh3zGfPFgJKs=
|
aidanwoods.dev/go-paseto v1.6.0/go.mod h1:LdqkL0Z2mLL0kBWzmHVR1cGFniX+zyOweQmbNKYrDxQ=
|
||||||
|
aidanwoods.dev/go-result v0.3.1 h1:ee98hpohYUVYbI+pa6gUHTyoRerIudgjky/IPSowDXQ=
|
||||||
|
aidanwoods.dev/go-result v0.3.1/go.mod h1:GKnFg8p/BKulVD3wsfULiPhpPmrTWyiTIbz8EWuUqSk=
|
||||||
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3 h1:k0NM+1XP3ebvfTvZfiHcyEZc0Drci5oxjZjE7L/xDdE=
|
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3 h1:k0NM+1XP3ebvfTvZfiHcyEZc0Drci5oxjZjE7L/xDdE=
|
||||||
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3/go.mod h1:pXfrNASG7JyxL30Zof3b1vbpd1dsHePTh3zGfPFgJKs=
|
codeberg.org/ungo/env v0.0.0-20260328142946-76f69daf34a3/go.mod h1:pXfrNASG7JyxL30Zof3b1vbpd1dsHePTh3zGfPFgJKs=
|
||||||
codeberg.org/ungo/gonsole v0.1.0 h1:QE/qpSyovejIXzIh29tzmrwgDWfaKUqNTCMZPJEDfvY=
|
codeberg.org/ungo/gonsole v0.1.0 h1:QE/qpSyovejIXzIh29tzmrwgDWfaKUqNTCMZPJEDfvY=
|
||||||
|
|
@ -30,10 +32,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
|
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||||
|
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ type Config struct {
|
||||||
Database DatabaseConfig `env:"DATABASE_"`
|
Database DatabaseConfig `env:"DATABASE_"`
|
||||||
Server ServerConfig `env:"SERVER_"`
|
Server ServerConfig `env:"SERVER_"`
|
||||||
Discord DiscordConfig `env:"DISCORD_"`
|
Discord DiscordConfig `env:"DISCORD_"`
|
||||||
|
Auth AuthConfig `env:"AUTH_"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
|
|
@ -27,6 +28,10 @@ type DiscordConfig struct {
|
||||||
RedirectURI string `env:"REDIRECT_URI,default=http://localhost:8080/auth/callback"`
|
RedirectURI string `env:"REDIRECT_URI,default=http://localhost:8080/auth/callback"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthConfig struct {
|
||||||
|
PasetoKey string `env:"PASETO_KEY,required"`
|
||||||
|
}
|
||||||
|
|
||||||
func LoadConfig() Config {
|
func LoadConfig() Config {
|
||||||
var cfg Config
|
var cfg Config
|
||||||
if err := env.Load(&cfg); err != nil {
|
if err := env.Load(&cfg); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,14 @@ type BotCoOwner struct {
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
CreatedAt *time.Time `json:"created_at"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
Revoked bool `json:"revoked"`
|
||||||
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
|
|
||||||
|
|
@ -14,12 +14,14 @@ type Querier interface {
|
||||||
CountBotsByUsername(ctx context.Context, arg CountBotsByUsernameParams) (int64, error)
|
CountBotsByUsername(ctx context.Context, arg CountBotsByUsernameParams) (int64, error)
|
||||||
CountVotesByBot(ctx context.Context, botID string) (int64, error)
|
CountVotesByBot(ctx context.Context, botID string) (int64, error)
|
||||||
CreateBot(ctx context.Context, arg CreateBotParams) (*Bot, error)
|
CreateBot(ctx context.Context, arg CreateBotParams) (*Bot, error)
|
||||||
|
CreateSession(ctx context.Context, arg CreateSessionParams) (*Session, error)
|
||||||
CreateUser(ctx context.Context, arg CreateUserParams) (*User, error)
|
CreateUser(ctx context.Context, arg CreateUserParams) (*User, error)
|
||||||
CreateVote(ctx context.Context, arg CreateVoteParams) (*Vote, error)
|
CreateVote(ctx context.Context, arg CreateVoteParams) (*Vote, error)
|
||||||
DeleteBot(ctx context.Context, id string) error
|
DeleteBot(ctx context.Context, id string) error
|
||||||
DeleteUser(ctx context.Context, id string) error
|
DeleteUser(ctx context.Context, id string) error
|
||||||
GetBot(ctx context.Context, id string) (*Bot, error)
|
GetBot(ctx context.Context, id string) (*Bot, error)
|
||||||
GetBotCoOwner(ctx context.Context, arg GetBotCoOwnerParams) (*BotCoOwner, error)
|
GetBotCoOwner(ctx context.Context, arg GetBotCoOwnerParams) (*BotCoOwner, error)
|
||||||
|
GetSession(ctx context.Context, id string) (*Session, error)
|
||||||
GetUser(ctx context.Context, id string) (*User, error)
|
GetUser(ctx context.Context, id string) (*User, error)
|
||||||
GetUserByUsername(ctx context.Context, username string) (*User, error)
|
GetUserByUsername(ctx context.Context, username string) (*User, error)
|
||||||
GetVote(ctx context.Context, arg GetVoteParams) (*Vote, error)
|
GetVote(ctx context.Context, arg GetVoteParams) (*Vote, error)
|
||||||
|
|
@ -33,6 +35,8 @@ type Querier interface {
|
||||||
ListVotesByUser(ctx context.Context, userID string) ([]*Vote, error)
|
ListVotesByUser(ctx context.Context, userID string) ([]*Vote, error)
|
||||||
RemoveAllCoOwnersByBot(ctx context.Context, botID string) error
|
RemoveAllCoOwnersByBot(ctx context.Context, botID string) error
|
||||||
RemoveBotCoOwner(ctx context.Context, arg RemoveBotCoOwnerParams) error
|
RemoveBotCoOwner(ctx context.Context, arg RemoveBotCoOwnerParams) error
|
||||||
|
RevokeAllUserSessions(ctx context.Context, userID string) error
|
||||||
|
RevokeSession(ctx context.Context, id string) error
|
||||||
SearchBotsByUsername(ctx context.Context, arg SearchBotsByUsernameParams) ([]*Bot, error)
|
SearchBotsByUsername(ctx context.Context, arg SearchBotsByUsernameParams) ([]*Bot, error)
|
||||||
UpdateBot(ctx context.Context, arg UpdateBotParams) (*Bot, error)
|
UpdateBot(ctx context.Context, arg UpdateBotParams) (*Bot, error)
|
||||||
UpdateBotStatus(ctx context.Context, arg UpdateBotStatusParams) (*Bot, error)
|
UpdateBotStatus(ctx context.Context, arg UpdateBotStatusParams) (*Bot, error)
|
||||||
|
|
|
||||||
74
internal/db/sessions.sql.go
Normal file
74
internal/db/sessions.sql.go
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.30.0
|
||||||
|
// source: sessions.sql
|
||||||
|
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const createSession = `-- name: CreateSession :one
|
||||||
|
INSERT INTO sessions (id, user_id, expires_at)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
RETURNING id, user_id, created_at, expires_at, revoked
|
||||||
|
`
|
||||||
|
|
||||||
|
type CreateSessionParams struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (*Session, error) {
|
||||||
|
row := q.db.QueryRow(ctx, createSession, arg.ID, arg.UserID, arg.ExpiresAt)
|
||||||
|
var i Session
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Revoked,
|
||||||
|
)
|
||||||
|
return &i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getSession = `-- name: GetSession :one
|
||||||
|
SELECT id, user_id, created_at, expires_at, revoked FROM sessions
|
||||||
|
WHERE id = $1 AND revoked = false AND expires_at > now()
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetSession(ctx context.Context, id string) (*Session, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getSession, id)
|
||||||
|
var i Session
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.ExpiresAt,
|
||||||
|
&i.Revoked,
|
||||||
|
)
|
||||||
|
return &i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const revokeAllUserSessions = `-- name: RevokeAllUserSessions :exec
|
||||||
|
UPDATE sessions SET revoked = true
|
||||||
|
WHERE user_id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) RevokeAllUserSessions(ctx context.Context, userID string) error {
|
||||||
|
_, err := q.db.Exec(ctx, revokeAllUserSessions, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const revokeSession = `-- name: RevokeSession :exec
|
||||||
|
UPDATE sessions SET revoked = true
|
||||||
|
WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) RevokeSession(ctx context.Context, id string) error {
|
||||||
|
_, err := q.db.Exec(ctx, revokeSession, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
17
internal/db/sql/migrations/20260419123958_sessions.sql
Normal file
17
internal/db/sql/migrations/20260419123958_sessions.sql
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
-- +goose Up
|
||||||
|
-- +goose StatementBegin
|
||||||
|
create table sessions (
|
||||||
|
id text primary key, -- jti (random uuid)
|
||||||
|
user_id text not null references users (id) on delete cascade,
|
||||||
|
created_at timestamp with time zone default now(),
|
||||||
|
expires_at timestamp with time zone not null,
|
||||||
|
revoked boolean not null default false
|
||||||
|
);
|
||||||
|
|
||||||
|
create index sessions_user_id_idx on sessions (user_id);
|
||||||
|
-- +goose StatementEnd
|
||||||
|
|
||||||
|
-- +goose Down
|
||||||
|
-- +goose StatementBegin
|
||||||
|
drop table if exists sessions;
|
||||||
|
-- +goose StatementEnd
|
||||||
16
internal/db/sql/queries/sessions.sql
Normal file
16
internal/db/sql/queries/sessions.sql
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
-- name: CreateSession :one
|
||||||
|
INSERT INTO sessions (id, user_id, expires_at)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
RETURNING *;
|
||||||
|
|
||||||
|
-- name: GetSession :one
|
||||||
|
SELECT * FROM sessions
|
||||||
|
WHERE id = $1 AND revoked = false AND expires_at > now();
|
||||||
|
|
||||||
|
-- name: RevokeSession :exec
|
||||||
|
UPDATE sessions SET revoked = true
|
||||||
|
WHERE id = $1;
|
||||||
|
|
||||||
|
-- name: RevokeAllUserSessions :exec
|
||||||
|
UPDATE sessions SET revoked = true
|
||||||
|
WHERE user_id = $1;
|
||||||
|
|
@ -6,31 +6,63 @@ import (
|
||||||
|
|
||||||
"codeberg.org/nextgo/dbots/internal/db"
|
"codeberg.org/nextgo/dbots/internal/db"
|
||||||
"codeberg.org/nextgo/dbots/internal/errorutil"
|
"codeberg.org/nextgo/dbots/internal/errorutil"
|
||||||
|
"codeberg.org/nextgo/dbots/internal/token"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey string
|
type contextKey string
|
||||||
|
|
||||||
const userKey contextKey = "user"
|
// UserContextKey is exported so service routers can read the user from context.
|
||||||
|
var UserContextKey contextKey = "user"
|
||||||
|
|
||||||
|
// AuthMiddleware reads the PASETO session cookie, verifies it, checks it
|
||||||
|
// hasn't been revoked in the DB, then sets the *db.User on the context.
|
||||||
|
// Does NOT block unauthenticated requests — use AuthGuardMiddleware for that.
|
||||||
|
func AuthMiddleware(q *db.Queries, pasetoKeyHex string) func(http.Handler) http.Handler {
|
||||||
|
key, err := token.KeyFromHex(pasetoKeyHex)
|
||||||
|
if err != nil {
|
||||||
|
panic("middleware: invalid PASETO key: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
// AuthMiddleware is a middleware to set the user as context value.
|
|
||||||
// this middleware does not prevents the user from accessing the route
|
|
||||||
// if not authorized.
|
|
||||||
func AuthMiddleware(q *db.Queries) func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// ctx := context.WithValue(r.Context(), userKey, user) // mocked
|
c, err := r.Cookie("session")
|
||||||
next.ServeHTTP(w, r)
|
if err != nil {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := token.VerifyToken(key, c.Value)
|
||||||
|
if err != nil {
|
||||||
|
// Expired or tampered — clear the cookie and continue as anonymous.
|
||||||
|
http.SetCookie(w, &http.Cookie{Name: "session", MaxAge: -1, Path: "/"})
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the session hasn't been revoked server-side.
|
||||||
|
if _, err := q.GetSession(r.Context(), claims.JTI); err != nil {
|
||||||
|
http.SetCookie(w, &http.Cookie{Name: "session", MaxAge: -1, Path: "/"})
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := q.GetUser(r.Context(), claims.UserID)
|
||||||
|
if err != nil {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), UserContextKey, user)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthGuardMiddleware is a middleware that prevents the user
|
// AuthGuardMiddleware blocks requests where no authenticated user was set.
|
||||||
// from accessing the route if they are NOT authorized.
|
|
||||||
func AuthGuardMiddleware(next http.Handler) http.Handler {
|
func AuthGuardMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_, ok := r.Context().Value(userKey).(*db.User)
|
if _, ok := r.Context().Value(UserContextKey).(*db.User); !ok {
|
||||||
if !ok {
|
|
||||||
render.Render(w, r, errorutil.ErrUnauthorized)
|
render.Render(w, r, errorutil.ErrUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -38,10 +70,8 @@ func AuthGuardMiddleware(next http.Handler) http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUser returns the authenticated user from context.
|
||||||
|
// Only safe to call inside a route guarded by AuthGuardMiddleware.
|
||||||
func GetUser(ctx context.Context) *db.User {
|
func GetUser(ctx context.Context) *db.User {
|
||||||
user, ok := ctx.Value(userKey).(*db.User)
|
return ctx.Value(UserContextKey).(*db.User)
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return user
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ func NewServer(queries *db.Queries, config config.Config) *Server {
|
||||||
router.Use(middleware.Recoverer)
|
router.Use(middleware.Recoverer)
|
||||||
router.Use(middleware.RequestID)
|
router.Use(middleware.RequestID)
|
||||||
router.Use(middleware.RealIP)
|
router.Use(middleware.RealIP)
|
||||||
router.Use(customMiddlewares.AuthMiddleware(queries)) // todo: use this middleware only when necessary
|
router.Use(customMiddlewares.AuthMiddleware(queries, config.Auth.PasetoKey)) // todo: use this middleware only when necessary
|
||||||
// i am using this globally cus it uses mocked data lol
|
// i am using this globally cus it uses mocked data lol
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
|
|
@ -49,7 +49,7 @@ func (s *Server) Register() {
|
||||||
s.config.Discord.RedirectURI,
|
s.config.Discord.RedirectURI,
|
||||||
)
|
)
|
||||||
|
|
||||||
authRouter := auth.NewRouter(s.queries, discordClient)
|
authRouter := auth.NewRouter(s.queries, discordClient, s.config.Auth.PasetoKey)
|
||||||
botRouter := bot.NewRouter(s.queries, discordClient)
|
botRouter := bot.NewRouter(s.queries, discordClient)
|
||||||
adminRouter := admin.NewRouter(s.queries)
|
adminRouter := admin.NewRouter(s.queries)
|
||||||
|
|
||||||
|
|
|
||||||
67
internal/token/token.go
Normal file
67
internal/token/token.go
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"aidanwoods.dev/go-paseto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const TokenDuration = 7 * 24 * time.Hour
|
||||||
|
|
||||||
|
const (
|
||||||
|
claimUserID = "uid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyFromHex loads a V4 symmetric key from a 64-char hex string (AUTH_PASETO_KEY).
|
||||||
|
func KeyFromHex(h string) (paseto.V4SymmetricKey, error) {
|
||||||
|
b, err := hex.DecodeString(h)
|
||||||
|
if err != nil {
|
||||||
|
return paseto.V4SymmetricKey{}, err
|
||||||
|
}
|
||||||
|
return paseto.V4SymmetricKeyFromBytes(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssueToken creates a PASETO v4 local (encrypted) token carrying the
|
||||||
|
// user ID and a jti that maps to a row in the sessions table.
|
||||||
|
func IssueToken(key paseto.V4SymmetricKey, userID, jti string) (string, error) {
|
||||||
|
tok := paseto.NewToken()
|
||||||
|
tok.SetIssuedAt(time.Now())
|
||||||
|
tok.SetNotBefore(time.Now())
|
||||||
|
tok.SetExpiration(time.Now().Add(TokenDuration))
|
||||||
|
tok.SetJti(jti)
|
||||||
|
if err := tok.Set(claimUserID, userID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return tok.V4Encrypt(key, nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claims holds the verified payload extracted from a token.
|
||||||
|
type Claims struct {
|
||||||
|
UserID string
|
||||||
|
JTI string
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyToken decrypts and validates a PASETO v4 local token.
|
||||||
|
func VerifyToken(key paseto.V4SymmetricKey, raw string) (Claims, error) {
|
||||||
|
parser := paseto.NewParser()
|
||||||
|
parser.AddRule(paseto.NotExpired())
|
||||||
|
parser.AddRule(paseto.ValidAt(time.Now()))
|
||||||
|
|
||||||
|
tok, err := parser.ParseV4Local(key, raw, nil)
|
||||||
|
if err != nil {
|
||||||
|
return Claims{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID string
|
||||||
|
if err := tok.Get(claimUserID, &userID); err != nil {
|
||||||
|
return Claims{}, errors.New("missing uid claim")
|
||||||
|
}
|
||||||
|
jti, err := tok.GetJti()
|
||||||
|
if err != nil {
|
||||||
|
return Claims{}, errors.New("missing jti claim")
|
||||||
|
}
|
||||||
|
|
||||||
|
return Claims{UserID: userID, JTI: jti}, nil
|
||||||
|
}
|
||||||
|
|
@ -3,16 +3,12 @@ package auth
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/hex"
|
||||||
"errors"
|
|
||||||
|
|
||||||
"codeberg.org/nextgo/dbots/internal/db"
|
"codeberg.org/nextgo/dbots/internal/db"
|
||||||
"codeberg.org/nextgo/dbots/internal/discord"
|
"codeberg.org/nextgo/dbots/internal/discord"
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// todo: api keysssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss
|
|
||||||
// or sessions????????????
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
q *db.Queries
|
q *db.Queries
|
||||||
client *discord.Client
|
client *discord.Client
|
||||||
|
|
@ -22,44 +18,42 @@ func NewService(q *db.Queries, client *discord.Client) *Service {
|
||||||
return &Service{q: q, client: client}
|
return &Service{q: q, client: client}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateState produces a random OAuth state parameter.
|
// Callback exchanges the OAuth code for a Discord access token,
|
||||||
|
// fetches the Discord user, and upserts them in the database.
|
||||||
|
// It returns the db.User and the raw Discord access token
|
||||||
|
// (needed so the caller can store it in the session if desired).
|
||||||
|
func (s *Service) Callback(ctx context.Context, code string) (*db.User, *discord.TokenResponse, error) {
|
||||||
|
tok, err := s.client.ExchangeCode(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dUser, err := s.client.GetCurrentUser(ctx, tok.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.q.GetUser(ctx, dUser.ID)
|
||||||
|
if err != nil {
|
||||||
|
// First login — create the user.
|
||||||
|
user, err = s.q.CreateUser(ctx, db.CreateUserParams{
|
||||||
|
ID: dUser.ID,
|
||||||
|
Username: dUser.Username,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateState returns a cryptographically random hex string for
|
||||||
|
// the OAuth2 state parameter.
|
||||||
func GenerateState() (string, error) {
|
func GenerateState() (string, error) {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
if _, err := rand.Read(b); err != nil {
|
if _, err := rand.Read(b); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return base64.URLEncoding.EncodeToString(b), nil
|
return hex.EncodeToString(b), nil
|
||||||
}
|
|
||||||
|
|
||||||
// Callback handles the OAuth callback: exchanges the code, fetches the user,
|
|
||||||
// and upserts them in the database. Returns the local db.User.
|
|
||||||
func (s *Service) Callback(ctx context.Context, code string) (*db.User, error) {
|
|
||||||
token, err := s.client.ExchangeCode(ctx, code)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dUser, err := s.client.GetCurrentUser(ctx, token.AccessToken)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := s.q.UpdateUser(ctx, db.UpdateUserParams{
|
|
||||||
ID: dUser.ID,
|
|
||||||
Username: &dUser.Username,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
|
||||||
user, err = s.q.CreateUser(ctx, db.CreateUserParams{
|
|
||||||
ID: dUser.ID,
|
|
||||||
Username: dUser.Username,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,23 +2,32 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"codeberg.org/nextgo/dbots/internal/db"
|
"codeberg.org/nextgo/dbots/internal/db"
|
||||||
"codeberg.org/nextgo/dbots/internal/discord"
|
"codeberg.org/nextgo/dbots/internal/discord"
|
||||||
"codeberg.org/nextgo/dbots/internal/errorutil"
|
"codeberg.org/nextgo/dbots/internal/errorutil"
|
||||||
|
"codeberg.org/nextgo/dbots/internal/middleware"
|
||||||
|
"codeberg.org/nextgo/dbots/internal/token"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/render"
|
"github.com/go-chi/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const cookieName = "session"
|
||||||
|
|
||||||
type Router struct {
|
type Router struct {
|
||||||
auth *Service
|
auth *Service
|
||||||
router chi.Router
|
router chi.Router
|
||||||
|
pasetoKey string // hex-encoded AUTH_PASETO_KEY
|
||||||
|
queries *db.Queries
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouter(q *db.Queries, client *discord.Client) *Router {
|
func NewRouter(q *db.Queries, client *discord.Client, pasetoKey string) *Router {
|
||||||
return &Router{
|
return &Router{
|
||||||
auth: NewService(q, client),
|
auth: NewService(q, client),
|
||||||
router: chi.NewRouter(),
|
router: chi.NewRouter(),
|
||||||
|
pasetoKey: pasetoKey,
|
||||||
|
queries: q,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -26,12 +35,17 @@ func (r *Router) Routes() http.Handler {
|
||||||
r.router.Get("/login", r.login)
|
r.router.Get("/login", r.login)
|
||||||
r.router.Get("/callback", r.callback)
|
r.router.Get("/callback", r.callback)
|
||||||
r.router.Post("/logout", r.logout)
|
r.router.Post("/logout", r.logout)
|
||||||
r.router.Get("/me", r.me)
|
r.router.With(middleware.AuthGuardMiddleware).Get("/me", r.me)
|
||||||
return r.router
|
return r.router
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) me(w http.ResponseWriter, req *http.Request) {
|
func (r *Router) me(w http.ResponseWriter, req *http.Request) {
|
||||||
|
user := middleware.GetUser(req.Context())
|
||||||
|
if user == nil {
|
||||||
|
render.Render(w, req, errorutil.ErrUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
render.JSON(w, req, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) login(w http.ResponseWriter, req *http.Request) {
|
func (r *Router) login(w http.ResponseWriter, req *http.Request) {
|
||||||
|
|
@ -40,29 +54,102 @@ func (r *Router) login(w http.ResponseWriter, req *http.Request) {
|
||||||
render.Render(w, req, errorutil.ErrInternal(err))
|
render.Render(w, req, errorutil.ErrInternal(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// todo: store state in a short-lived cookie or session before redirecting
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "oauth_state",
|
||||||
|
Value: state,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
MaxAge: 300,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
|
||||||
http.Redirect(w, req, r.auth.client.AuthURL(state), http.StatusFound)
|
http.Redirect(w, req, r.auth.client.AuthURL(state), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) callback(w http.ResponseWriter, req *http.Request) {
|
func (r *Router) callback(w http.ResponseWriter, req *http.Request) {
|
||||||
// todo: validate state matches what was stored
|
stateCookie, err := req.Cookie("oauth_state")
|
||||||
|
if err != nil || stateCookie.Value != req.URL.Query().Get("state") {
|
||||||
|
render.Render(w, req, errorutil.ErrUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.SetCookie(w, &http.Cookie{Name: "oauth_state", MaxAge: -1, Path: "/"})
|
||||||
|
|
||||||
code := req.URL.Query().Get("code")
|
code := req.URL.Query().Get("code")
|
||||||
if code == "" {
|
if code == "" {
|
||||||
render.Render(w, req, errorutil.ErrInvalidRequest(nil))
|
render.Render(w, req, errorutil.ErrInvalidRequest(nil))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := r.auth.Callback(req.Context(), code)
|
user, _, err := r.auth.Callback(req.Context(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Render(w, req, errorutil.ErrInternal(err))
|
render.Render(w, req, errorutil.ErrInternal(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: create a session, set a cookie, then redirect to "/"
|
// Generate a session ID (jti) and persist it to the DB for server-side revocation.
|
||||||
render.JSON(w, req, user)
|
jti, err := GenerateState() // reuses the same crypto/rand helper
|
||||||
|
if err != nil {
|
||||||
|
render.Render(w, req, errorutil.ErrInternal(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = r.queries.CreateSession(req.Context(), db.CreateSessionParams{
|
||||||
|
ID: jti,
|
||||||
|
UserID: user.ID,
|
||||||
|
ExpiresAt: time.Now().Add(token.TokenDuration),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
render.Render(w, req, errorutil.ErrInternal(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := token.KeyFromHex(r.pasetoKey)
|
||||||
|
if err != nil {
|
||||||
|
render.Render(w, req, errorutil.ErrInternal(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := token.IssueToken(key, user.ID, jti)
|
||||||
|
if err != nil {
|
||||||
|
render.Render(w, req, errorutil.ErrInternal(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: cookieName,
|
||||||
|
Value: raw,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
MaxAge: int(token.TokenDuration.Seconds()),
|
||||||
|
Path: "/",
|
||||||
|
// Secure: true, // enable in production (HTTPS)
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(w, req, "/", http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// POST /auth/logout — revoke the session server-side and clear the cookie.
|
||||||
func (r *Router) logout(w http.ResponseWriter, req *http.Request) {
|
func (r *Router) logout(w http.ResponseWriter, req *http.Request) {
|
||||||
// todo: delete session
|
c, err := req.Cookie(cookieName)
|
||||||
|
if err != nil {
|
||||||
|
// Already logged out.
|
||||||
|
render.NoContent(w, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := token.KeyFromHex(r.pasetoKey)
|
||||||
|
if err == nil {
|
||||||
|
if claims, err := token.VerifyToken(key, c.Value); err == nil {
|
||||||
|
// Best-effort: ignore DB errors, the cookie will be cleared anyway.
|
||||||
|
_ = r.queries.RevokeSession(req.Context(), claims.JTI)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: cookieName,
|
||||||
|
MaxAge: -1,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
render.NoContent(w, req)
|
render.NoContent(w, req)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue