refactor: bot context

This commit is contained in:
Elisiei Yehorov 2026-04-17 23:46:03 +02:00
parent 9df1d0de56
commit 9ef5e33b82
Signed by: elisiei
GPG key ID: BA1D158DCE3DF089
8 changed files with 131 additions and 41 deletions

5
.gitignore vendored
View file

@ -26,8 +26,11 @@ go.work.sum
# env file # env file
.env .env
.direnv .direnv/
# pg
.pgdata/
.pgsocket/
# Editor/IDE # Editor/IDE
# .idea/ # .idea/
# .vscode/ # .vscode/

View file

@ -14,6 +14,7 @@
"x86_64-darwin" "x86_64-darwin"
"aarch64-darwin" "aarch64-darwin"
]; ];
forEachSupportedSystem = forEachSupportedSystem =
f: f:
inputs.nixpkgs.lib.genAttrs supportedSystems ( inputs.nixpkgs.lib.genAttrs supportedSystems (
@ -39,7 +40,33 @@
go go
sqlc sqlc
goose goose
postgresql
(pkgs.writeShellScriptBin "pg-start" ''
export PGDATA=${"$PWD"}/.pgdata
export PGHOST=${"$PWD"}/.pgsocket
mkdir -p "$PGDATA"
mkdir -p "$PGHOST"
if [ ! -f "$PGDATA/PG_VERSION" ]; then
echo "Initializing database..."
initdb -D "$PGDATA"
fi
echo "Starting postgres..."
pg_ctl -D "$PGDATA" -l logfile -o "-k $PGHOST" start
'')
(pkgs.writeShellScriptBin "pg-stop" ''
export PGDATA=${"$PWD"}/.pgdata
pg_ctl -D "$PGDATA" stop
'')
]; ];
shellHook = ''
export PGDATA=$PWD/.pgdata
'';
}; };
} }
); );

View file

@ -2,12 +2,14 @@ package admin
import ( import (
"context" "context"
"errors"
"log/slog" "log/slog"
"strings" "strings"
"codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/db"
"codeberg.org/nextgo/dbots/internal/errorutil" "codeberg.org/nextgo/dbots/internal/errorutil"
"codeberg.org/nextgo/dbots/internal/paginate" "codeberg.org/nextgo/dbots/internal/paginate"
"github.com/jackc/pgx/v5"
) )
type Service struct { type Service struct {
@ -18,6 +20,20 @@ func NewService(q *db.Queries) *Service {
return &Service{q: q} return &Service{q: q}
} }
func (s *Service) Get(ctx context.Context, id string) (*db.Bot, error) {
bot, err := s.q.GetBot(ctx, id)
if err != nil {
slog.Error("error getting bot", "err", err, "id", id)
if errors.Is(err, pgx.ErrNoRows) {
return nil, errorutil.ErrNotFound.Err
} else {
return nil, err
}
}
return bot, nil
}
func (s *Service) ListBots(ctx context.Context, status db.BotStatus, p paginate.Params) (paginate.Page[*db.Bot], error) { func (s *Service) ListBots(ctx context.Context, status db.BotStatus, p paginate.Params) (paginate.Page[*db.Bot], error) {
status = db.BotStatus(strings.ToLower(status.String())) status = db.BotStatus(strings.ToLower(status.String()))
total, err := s.q.CountBotsByUsername(ctx, db.CountBotsByUsernameParams{ total, err := s.q.CountBotsByUsername(ctx, db.CountBotsByUsernameParams{

View file

@ -5,7 +5,9 @@ import (
"codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/db"
"codeberg.org/nextgo/dbots/internal/errorutil" "codeberg.org/nextgo/dbots/internal/errorutil"
"codeberg.org/nextgo/dbots/internal/middleware"
"codeberg.org/nextgo/dbots/internal/paginate" "codeberg.org/nextgo/dbots/internal/paginate"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/render" "github.com/go-chi/render"
) )
@ -23,14 +25,24 @@ func NewRouter(q *db.Queries) *Router {
} }
func (r *Router) Routes() http.Handler { func (r *Router) Routes() http.Handler {
r.router.Get("/bots", r.listBots)
r.router.Route("/bots", func(router chi.Router) { r.router.Route("/bots", func(router chi.Router) {
router.Get("/", r.listBots)
router.Route("/{botID}", func(b chi.Router) {
b.Use(middleware.BotContext(r.admin.q))
b.Get("/", r.getBot)
})
}) })
return r.router return r.router
} }
func (r *Router) getBot(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
bot := middleware.GetBot(ctx)
render.JSON(w, req, bot)
}
func (r *Router) listBots(w http.ResponseWriter, req *http.Request) { func (r *Router) listBots(w http.ResponseWriter, req *http.Request) {
ctx := req.Context() ctx := req.Context()
status := db.BotStatus(req.URL.Query().Get("s")) status := db.BotStatus(req.URL.Query().Get("s"))

View file

@ -11,6 +11,7 @@ import (
"codeberg.org/nextgo/dbots/internal/errorutil" "codeberg.org/nextgo/dbots/internal/errorutil"
"codeberg.org/nextgo/dbots/internal/middleware" "codeberg.org/nextgo/dbots/internal/middleware"
"codeberg.org/nextgo/dbots/internal/paginate" "codeberg.org/nextgo/dbots/internal/paginate"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
) )
@ -153,7 +154,17 @@ func validateCoOwner(id, mainOwnerID string) error {
} }
func (s *Service) Get(ctx context.Context, id string) (*db.Bot, error) { func (s *Service) Get(ctx context.Context, id string) (*db.Bot, error) {
return s.q.GetBot(ctx, id) bot, err := s.q.GetBot(ctx, id)
if err != nil {
slog.Error("error getting bot", "err", err, "id", id)
if errors.Is(err, pgx.ErrNoRows) {
return nil, errorutil.ErrNotFound.Err
} else {
return nil, err
}
}
return bot, nil
} }
func (s *Service) List( func (s *Service) List(

View file

@ -1,7 +1,6 @@
package bot package bot
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
@ -14,10 +13,6 @@ import (
"github.com/go-chi/render" "github.com/go-chi/render"
) )
type contextKey string
const botKey contextKey = "bot"
type Router struct { type Router struct {
bots *Service bots *Service
router chi.Router router chi.Router
@ -34,7 +29,7 @@ func (r *Router) Routes() http.Handler {
r.router.Get("/", r.listBots) // todo: deprecate this r.router.Get("/", r.listBots) // todo: deprecate this
r.router.With(middleware.AuthGuardMiddleware).Post("/", r.submitBot) r.router.With(middleware.AuthGuardMiddleware).Post("/", r.submitBot)
r.router.Route("/{botID}", func(router chi.Router) { r.router.Route("/{botID}", func(router chi.Router) {
router.Use(r.BotContext) router.Use(middleware.BotContext(r.bots.q))
router.With(r.BotCache).Get("/", r.getBot) router.With(r.BotCache).Get("/", r.getBot)
router.Route("/co-owners", func(c chi.Router) { router.Route("/co-owners", func(c chi.Router) {
c.Use(middleware.AuthGuardMiddleware) c.Use(middleware.AuthGuardMiddleware)
@ -57,10 +52,13 @@ func (r *Router) submitBot(w http.ResponseWriter, req *http.Request) {
ctx := req.Context() ctx := req.Context()
bot, err := r.bots.Submit(ctx, data) bot, err := r.bots.Submit(ctx, data)
if errors.Is(err, errorutil.ErrBotAlreadyExists) { if err != nil {
render.Render(w, req, errorutil.ErrInvalidRequest(err)) if errors.Is(err, errorutil.ErrBotAlreadyExists) {
} else { render.Render(w, req, errorutil.ErrInvalidRequest(err))
render.Render(w, req, errorutil.ErrInternal(err)) } else {
render.Render(w, req, errorutil.ErrInternal(err))
}
return
} }
render.Status(req, http.StatusCreated) render.Status(req, http.StatusCreated)
@ -69,12 +67,7 @@ func (r *Router) submitBot(w http.ResponseWriter, req *http.Request) {
func (r *Router) getBot(w http.ResponseWriter, req *http.Request) { func (r *Router) getBot(w http.ResponseWriter, req *http.Request) {
ctx := req.Context() ctx := req.Context()
bot := middleware.GetBot(ctx)
bot, ok := ctx.Value(botKey).(*db.Bot)
if !ok {
render.Render(w, req, errorutil.ErrInvalidRequest(nil))
return
}
render.JSON(w, req, bot) render.JSON(w, req, bot)
} }
@ -95,7 +88,7 @@ func (r *Router) listBots(w http.ResponseWriter, req *http.Request) {
func (r *Router) listCoOwners(w http.ResponseWriter, req *http.Request) { func (r *Router) listCoOwners(w http.ResponseWriter, req *http.Request) {
ctx := req.Context() ctx := req.Context()
bot := ctx.Value(botKey).(*db.Bot) bot := middleware.GetBot(ctx)
owners, err := r.bots.ListCoOwners(ctx, bot.ID) owners, err := r.bots.ListCoOwners(ctx, bot.ID)
if err != nil { if err != nil {
@ -108,7 +101,7 @@ func (r *Router) listCoOwners(w http.ResponseWriter, req *http.Request) {
func (r *Router) addCoOwner(w http.ResponseWriter, req *http.Request) { func (r *Router) addCoOwner(w http.ResponseWriter, req *http.Request) {
ctx := req.Context() ctx := req.Context()
bot := ctx.Value(botKey).(*db.Bot) bot := middleware.GetBot(ctx)
userID := chi.URLParam(req, "userID") userID := chi.URLParam(req, "userID")
if err := r.bots.AddCoOwner(ctx, bot.ID, userID); err != nil { if err := r.bots.AddCoOwner(ctx, bot.ID, userID); err != nil {
@ -128,7 +121,7 @@ func (r *Router) addCoOwner(w http.ResponseWriter, req *http.Request) {
func (r *Router) removeCoOwner(w http.ResponseWriter, req *http.Request) { func (r *Router) removeCoOwner(w http.ResponseWriter, req *http.Request) {
ctx := req.Context() ctx := req.Context()
bot := ctx.Value(botKey).(*db.Bot) bot := middleware.GetBot(ctx)
userID := chi.URLParam(req, "userID") userID := chi.URLParam(req, "userID")
if err := r.bots.RemoveCoOwner(ctx, bot.ID, userID); err != nil { if err := r.bots.RemoveCoOwner(ctx, bot.ID, userID); err != nil {
@ -146,21 +139,6 @@ func (r *Router) removeCoOwner(w http.ResponseWriter, req *http.Request) {
render.NoContent(w, req) render.NoContent(w, req)
} }
func (r *Router) BotContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
botID := chi.URLParam(req, "botID")
bot, err := r.bots.Get(ctx, botID)
if err != nil {
render.Render(w, req, errorutil.ErrNotFound)
return
}
ctx = context.WithValue(ctx, botKey, bot)
next.ServeHTTP(w, req.WithContext(ctx))
})
}
func (r *Router) BotCache(next http.Handler) http.Handler { func (r *Router) BotCache(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Cache-Control", "max-age=3600") w.Header().Add("Cache-Control", "max-age=3600")

View file

@ -0,0 +1,45 @@
package middleware
import (
"context"
"errors"
"net/http"
"codeberg.org/nextgo/dbots/internal/db"
"codeberg.org/nextgo/dbots/internal/errorutil"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"github.com/jackc/pgx/v5"
)
const botKey contextKey = "bot"
func BotContext(q *db.Queries) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
botID := chi.URLParam(req, "botID")
bot, err := q.GetBot(ctx, botID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
render.Render(w, req, errorutil.ErrNotFound)
} else {
render.Render(w, req, errorutil.ErrInternal(err))
}
return
}
ctx = context.WithValue(ctx, botKey, bot)
next.ServeHTTP(w, req.WithContext(ctx))
})
}
}
func GetBot(ctx context.Context) *db.Bot {
bot, ok := ctx.Value(botKey).(*db.Bot)
if !ok {
return nil
}
return bot
}

View file

@ -24,9 +24,7 @@ type Server struct {
func NewServer(queries *db.Queries) *Server { func NewServer(queries *db.Queries) *Server {
router := chi.NewMux() router := chi.NewMux()
router.Use(httplog.RequestLogger(slog.Default(), &httplog.Options{ router.Use(httplog.RequestLogger(slog.Default(), &httplog.Options{}))
RecoverPanics: true,
}))
router.Use(middleware.Recoverer) router.Use(middleware.Recoverer)
router.Use(middleware.RequestID) router.Use(middleware.RequestID)
router.Use(middleware.RealIP) router.Use(middleware.RealIP)