From 9ef5e33b82da60db9ba64ffee22a3cac674b8ff0 Mon Sep 17 00:00:00 2001 From: Elisiei Yehorov Date: Fri, 17 Apr 2026 23:46:03 +0200 Subject: [PATCH] refactor: bot context --- .gitignore | 5 ++++- flake.nix | 27 ++++++++++++++++++++++ internal/admin/admin.go | 16 +++++++++++++ internal/admin/router.go | 16 +++++++++++-- internal/bot/bot.go | 13 ++++++++++- internal/bot/router.go | 46 ++++++++++---------------------------- internal/middleware/bot.go | 45 +++++++++++++++++++++++++++++++++++++ internal/server/server.go | 4 +--- 8 files changed, 131 insertions(+), 41 deletions(-) create mode 100644 internal/middleware/bot.go diff --git a/.gitignore b/.gitignore index 50b5dca..b97556f 100644 --- a/.gitignore +++ b/.gitignore @@ -26,8 +26,11 @@ go.work.sum # env file .env -.direnv +.direnv/ +# pg +.pgdata/ +.pgsocket/ # Editor/IDE # .idea/ # .vscode/ diff --git a/flake.nix b/flake.nix index ac1d74a..be96a6b 100644 --- a/flake.nix +++ b/flake.nix @@ -14,6 +14,7 @@ "x86_64-darwin" "aarch64-darwin" ]; + forEachSupportedSystem = f: inputs.nixpkgs.lib.genAttrs supportedSystems ( @@ -39,7 +40,33 @@ go sqlc goose + postgresql + + (pkgs.writeShellScriptBin "pg-start" '' + export PGDATA=${"$PWD"}/.pgdata + export PGHOST=${"$PWD"}/.pgsocket + + mkdir -p "$PGDATA" + mkdir -p "$PGHOST" + + if [ ! -f "$PGDATA/PG_VERSION" ]; then + echo "Initializing database..." + initdb -D "$PGDATA" + fi + + echo "Starting postgres..." + pg_ctl -D "$PGDATA" -l logfile -o "-k $PGHOST" start + '') + + (pkgs.writeShellScriptBin "pg-stop" '' + export PGDATA=${"$PWD"}/.pgdata + pg_ctl -D "$PGDATA" stop + '') ]; + + shellHook = '' + export PGDATA=$PWD/.pgdata + ''; }; } ); diff --git a/internal/admin/admin.go b/internal/admin/admin.go index 372148c..634951d 100644 --- a/internal/admin/admin.go +++ b/internal/admin/admin.go @@ -2,12 +2,14 @@ package admin import ( "context" + "errors" "log/slog" "strings" "codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/errorutil" "codeberg.org/nextgo/dbots/internal/paginate" + "github.com/jackc/pgx/v5" ) type Service struct { @@ -18,6 +20,20 @@ func NewService(q *db.Queries) *Service { return &Service{q: q} } +func (s *Service) Get(ctx context.Context, id string) (*db.Bot, error) { + bot, err := s.q.GetBot(ctx, id) + if err != nil { + slog.Error("error getting bot", "err", err, "id", id) + if errors.Is(err, pgx.ErrNoRows) { + return nil, errorutil.ErrNotFound.Err + } else { + return nil, err + } + } + + return bot, nil +} + func (s *Service) ListBots(ctx context.Context, status db.BotStatus, p paginate.Params) (paginate.Page[*db.Bot], error) { status = db.BotStatus(strings.ToLower(status.String())) total, err := s.q.CountBotsByUsername(ctx, db.CountBotsByUsernameParams{ diff --git a/internal/admin/router.go b/internal/admin/router.go index 8bcba13..78af96a 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -5,7 +5,9 @@ import ( "codeberg.org/nextgo/dbots/internal/db" "codeberg.org/nextgo/dbots/internal/errorutil" + "codeberg.org/nextgo/dbots/internal/middleware" "codeberg.org/nextgo/dbots/internal/paginate" + "github.com/go-chi/chi/v5" "github.com/go-chi/render" ) @@ -23,14 +25,24 @@ func NewRouter(q *db.Queries) *Router { } func (r *Router) Routes() http.Handler { - r.router.Get("/bots", r.listBots) r.router.Route("/bots", func(router chi.Router) { - + router.Get("/", r.listBots) + router.Route("/{botID}", func(b chi.Router) { + b.Use(middleware.BotContext(r.admin.q)) + b.Get("/", r.getBot) + }) }) return r.router } +func (r *Router) getBot(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + bot := middleware.GetBot(ctx) + + render.JSON(w, req, bot) +} + func (r *Router) listBots(w http.ResponseWriter, req *http.Request) { ctx := req.Context() status := db.BotStatus(req.URL.Query().Get("s")) diff --git a/internal/bot/bot.go b/internal/bot/bot.go index 893e708..2417bb9 100644 --- a/internal/bot/bot.go +++ b/internal/bot/bot.go @@ -11,6 +11,7 @@ import ( "codeberg.org/nextgo/dbots/internal/errorutil" "codeberg.org/nextgo/dbots/internal/middleware" "codeberg.org/nextgo/dbots/internal/paginate" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" ) @@ -153,7 +154,17 @@ func validateCoOwner(id, mainOwnerID string) error { } func (s *Service) Get(ctx context.Context, id string) (*db.Bot, error) { - return s.q.GetBot(ctx, id) + bot, err := s.q.GetBot(ctx, id) + if err != nil { + slog.Error("error getting bot", "err", err, "id", id) + if errors.Is(err, pgx.ErrNoRows) { + return nil, errorutil.ErrNotFound.Err + } else { + return nil, err + } + } + + return bot, nil } func (s *Service) List( diff --git a/internal/bot/router.go b/internal/bot/router.go index f7a56d8..5b1070c 100644 --- a/internal/bot/router.go +++ b/internal/bot/router.go @@ -1,7 +1,6 @@ package bot import ( - "context" "errors" "net/http" @@ -14,10 +13,6 @@ import ( "github.com/go-chi/render" ) -type contextKey string - -const botKey contextKey = "bot" - type Router struct { bots *Service router chi.Router @@ -34,7 +29,7 @@ func (r *Router) Routes() http.Handler { r.router.Get("/", r.listBots) // todo: deprecate this r.router.With(middleware.AuthGuardMiddleware).Post("/", r.submitBot) r.router.Route("/{botID}", func(router chi.Router) { - router.Use(r.BotContext) + router.Use(middleware.BotContext(r.bots.q)) router.With(r.BotCache).Get("/", r.getBot) router.Route("/co-owners", func(c chi.Router) { c.Use(middleware.AuthGuardMiddleware) @@ -57,10 +52,13 @@ func (r *Router) submitBot(w http.ResponseWriter, req *http.Request) { ctx := req.Context() bot, err := r.bots.Submit(ctx, data) - if errors.Is(err, errorutil.ErrBotAlreadyExists) { - render.Render(w, req, errorutil.ErrInvalidRequest(err)) - } else { - render.Render(w, req, errorutil.ErrInternal(err)) + if err != nil { + if errors.Is(err, errorutil.ErrBotAlreadyExists) { + render.Render(w, req, errorutil.ErrInvalidRequest(err)) + } else { + render.Render(w, req, errorutil.ErrInternal(err)) + } + return } render.Status(req, http.StatusCreated) @@ -69,12 +67,7 @@ func (r *Router) submitBot(w http.ResponseWriter, req *http.Request) { func (r *Router) getBot(w http.ResponseWriter, req *http.Request) { ctx := req.Context() - - bot, ok := ctx.Value(botKey).(*db.Bot) - if !ok { - render.Render(w, req, errorutil.ErrInvalidRequest(nil)) - return - } + bot := middleware.GetBot(ctx) render.JSON(w, req, bot) } @@ -95,7 +88,7 @@ func (r *Router) listBots(w http.ResponseWriter, req *http.Request) { func (r *Router) listCoOwners(w http.ResponseWriter, req *http.Request) { ctx := req.Context() - bot := ctx.Value(botKey).(*db.Bot) + bot := middleware.GetBot(ctx) owners, err := r.bots.ListCoOwners(ctx, bot.ID) if err != nil { @@ -108,7 +101,7 @@ func (r *Router) listCoOwners(w http.ResponseWriter, req *http.Request) { func (r *Router) addCoOwner(w http.ResponseWriter, req *http.Request) { ctx := req.Context() - bot := ctx.Value(botKey).(*db.Bot) + bot := middleware.GetBot(ctx) userID := chi.URLParam(req, "userID") if err := r.bots.AddCoOwner(ctx, bot.ID, userID); err != nil { @@ -128,7 +121,7 @@ func (r *Router) addCoOwner(w http.ResponseWriter, req *http.Request) { func (r *Router) removeCoOwner(w http.ResponseWriter, req *http.Request) { ctx := req.Context() - bot := ctx.Value(botKey).(*db.Bot) + bot := middleware.GetBot(ctx) userID := chi.URLParam(req, "userID") if err := r.bots.RemoveCoOwner(ctx, bot.ID, userID); err != nil { @@ -146,21 +139,6 @@ func (r *Router) removeCoOwner(w http.ResponseWriter, req *http.Request) { render.NoContent(w, req) } -func (r *Router) BotContext(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ctx := req.Context() - botID := chi.URLParam(req, "botID") - bot, err := r.bots.Get(ctx, botID) - if err != nil { - render.Render(w, req, errorutil.ErrNotFound) - return - } - - ctx = context.WithValue(ctx, botKey, bot) - next.ServeHTTP(w, req.WithContext(ctx)) - }) -} - func (r *Router) BotCache(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Header().Add("Cache-Control", "max-age=3600") diff --git a/internal/middleware/bot.go b/internal/middleware/bot.go new file mode 100644 index 0000000..1bafab8 --- /dev/null +++ b/internal/middleware/bot.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "context" + "errors" + "net/http" + + "codeberg.org/nextgo/dbots/internal/db" + "codeberg.org/nextgo/dbots/internal/errorutil" + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/jackc/pgx/v5" +) + +const botKey contextKey = "bot" + +func BotContext(q *db.Queries) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + botID := chi.URLParam(req, "botID") + bot, err := q.GetBot(ctx, botID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + render.Render(w, req, errorutil.ErrNotFound) + } else { + render.Render(w, req, errorutil.ErrInternal(err)) + } + return + } + + ctx = context.WithValue(ctx, botKey, bot) + next.ServeHTTP(w, req.WithContext(ctx)) + }) + } +} + +func GetBot(ctx context.Context) *db.Bot { + bot, ok := ctx.Value(botKey).(*db.Bot) + if !ok { + return nil + } + + return bot +} diff --git a/internal/server/server.go b/internal/server/server.go index b754886..1a1724c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -24,9 +24,7 @@ type Server struct { func NewServer(queries *db.Queries) *Server { router := chi.NewMux() - router.Use(httplog.RequestLogger(slog.Default(), &httplog.Options{ - RecoverPanics: true, - })) + router.Use(httplog.RequestLogger(slog.Default(), &httplog.Options{})) router.Use(middleware.Recoverer) router.Use(middleware.RequestID) router.Use(middleware.RealIP)