refactor: bot context
This commit is contained in:
parent
9df1d0de56
commit
9ef5e33b82
8 changed files with 131 additions and 41 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -26,8 +26,11 @@ go.work.sum
|
|||
|
||||
# env file
|
||||
.env
|
||||
.direnv
|
||||
.direnv/
|
||||
|
||||
# pg
|
||||
.pgdata/
|
||||
.pgsocket/
|
||||
# Editor/IDE
|
||||
# .idea/
|
||||
# .vscode/
|
||||
|
|
|
|||
27
flake.nix
27
flake.nix
|
|
@ -14,6 +14,7 @@
|
|||
"x86_64-darwin"
|
||||
"aarch64-darwin"
|
||||
];
|
||||
|
||||
forEachSupportedSystem =
|
||||
f:
|
||||
inputs.nixpkgs.lib.genAttrs supportedSystems (
|
||||
|
|
@ -39,7 +40,33 @@
|
|||
go
|
||||
sqlc
|
||||
goose
|
||||
postgresql
|
||||
|
||||
(pkgs.writeShellScriptBin "pg-start" ''
|
||||
export PGDATA=${"$PWD"}/.pgdata
|
||||
export PGHOST=${"$PWD"}/.pgsocket
|
||||
|
||||
mkdir -p "$PGDATA"
|
||||
mkdir -p "$PGHOST"
|
||||
|
||||
if [ ! -f "$PGDATA/PG_VERSION" ]; then
|
||||
echo "Initializing database..."
|
||||
initdb -D "$PGDATA"
|
||||
fi
|
||||
|
||||
echo "Starting postgres..."
|
||||
pg_ctl -D "$PGDATA" -l logfile -o "-k $PGHOST" start
|
||||
'')
|
||||
|
||||
(pkgs.writeShellScriptBin "pg-stop" ''
|
||||
export PGDATA=${"$PWD"}/.pgdata
|
||||
pg_ctl -D "$PGDATA" stop
|
||||
'')
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
export PGDATA=$PWD/.pgdata
|
||||
'';
|
||||
};
|
||||
}
|
||||
);
|
||||
|
|
|
|||
|
|
@ -2,12 +2,14 @@ package admin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"codeberg.org/nextgo/dbots/internal/db"
|
||||
"codeberg.org/nextgo/dbots/internal/errorutil"
|
||||
"codeberg.org/nextgo/dbots/internal/paginate"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
|
|
@ -18,6 +20,20 @@ func NewService(q *db.Queries) *Service {
|
|||
return &Service{q: q}
|
||||
}
|
||||
|
||||
func (s *Service) Get(ctx context.Context, id string) (*db.Bot, error) {
|
||||
bot, err := s.q.GetBot(ctx, id)
|
||||
if err != nil {
|
||||
slog.Error("error getting bot", "err", err, "id", id)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, errorutil.ErrNotFound.Err
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return bot, nil
|
||||
}
|
||||
|
||||
func (s *Service) ListBots(ctx context.Context, status db.BotStatus, p paginate.Params) (paginate.Page[*db.Bot], error) {
|
||||
status = db.BotStatus(strings.ToLower(status.String()))
|
||||
total, err := s.q.CountBotsByUsername(ctx, db.CountBotsByUsernameParams{
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ import (
|
|||
|
||||
"codeberg.org/nextgo/dbots/internal/db"
|
||||
"codeberg.org/nextgo/dbots/internal/errorutil"
|
||||
"codeberg.org/nextgo/dbots/internal/middleware"
|
||||
"codeberg.org/nextgo/dbots/internal/paginate"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
|
@ -23,14 +25,24 @@ func NewRouter(q *db.Queries) *Router {
|
|||
}
|
||||
|
||||
func (r *Router) Routes() http.Handler {
|
||||
r.router.Get("/bots", r.listBots)
|
||||
r.router.Route("/bots", func(router chi.Router) {
|
||||
|
||||
router.Get("/", r.listBots)
|
||||
router.Route("/{botID}", func(b chi.Router) {
|
||||
b.Use(middleware.BotContext(r.admin.q))
|
||||
b.Get("/", r.getBot)
|
||||
})
|
||||
})
|
||||
|
||||
return r.router
|
||||
}
|
||||
|
||||
func (r *Router) getBot(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
bot := middleware.GetBot(ctx)
|
||||
|
||||
render.JSON(w, req, bot)
|
||||
}
|
||||
|
||||
func (r *Router) listBots(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
status := db.BotStatus(req.URL.Query().Get("s"))
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"codeberg.org/nextgo/dbots/internal/errorutil"
|
||||
"codeberg.org/nextgo/dbots/internal/middleware"
|
||||
"codeberg.org/nextgo/dbots/internal/paginate"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
|
|
@ -153,7 +154,17 @@ func validateCoOwner(id, mainOwnerID string) error {
|
|||
}
|
||||
|
||||
func (s *Service) Get(ctx context.Context, id string) (*db.Bot, error) {
|
||||
return s.q.GetBot(ctx, id)
|
||||
bot, err := s.q.GetBot(ctx, id)
|
||||
if err != nil {
|
||||
slog.Error("error getting bot", "err", err, "id", id)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, errorutil.ErrNotFound.Err
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return bot, nil
|
||||
}
|
||||
|
||||
func (s *Service) List(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package bot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
|
|
@ -14,10 +13,6 @@ import (
|
|||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const botKey contextKey = "bot"
|
||||
|
||||
type Router struct {
|
||||
bots *Service
|
||||
router chi.Router
|
||||
|
|
@ -34,7 +29,7 @@ func (r *Router) Routes() http.Handler {
|
|||
r.router.Get("/", r.listBots) // todo: deprecate this
|
||||
r.router.With(middleware.AuthGuardMiddleware).Post("/", r.submitBot)
|
||||
r.router.Route("/{botID}", func(router chi.Router) {
|
||||
router.Use(r.BotContext)
|
||||
router.Use(middleware.BotContext(r.bots.q))
|
||||
router.With(r.BotCache).Get("/", r.getBot)
|
||||
router.Route("/co-owners", func(c chi.Router) {
|
||||
c.Use(middleware.AuthGuardMiddleware)
|
||||
|
|
@ -57,11 +52,14 @@ func (r *Router) submitBot(w http.ResponseWriter, req *http.Request) {
|
|||
|
||||
ctx := req.Context()
|
||||
bot, err := r.bots.Submit(ctx, data)
|
||||
if err != nil {
|
||||
if errors.Is(err, errorutil.ErrBotAlreadyExists) {
|
||||
render.Render(w, req, errorutil.ErrInvalidRequest(err))
|
||||
} else {
|
||||
render.Render(w, req, errorutil.ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
render.Status(req, http.StatusCreated)
|
||||
render.JSON(w, req, bot)
|
||||
|
|
@ -69,12 +67,7 @@ func (r *Router) submitBot(w http.ResponseWriter, req *http.Request) {
|
|||
|
||||
func (r *Router) getBot(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
|
||||
bot, ok := ctx.Value(botKey).(*db.Bot)
|
||||
if !ok {
|
||||
render.Render(w, req, errorutil.ErrInvalidRequest(nil))
|
||||
return
|
||||
}
|
||||
bot := middleware.GetBot(ctx)
|
||||
|
||||
render.JSON(w, req, bot)
|
||||
}
|
||||
|
|
@ -95,7 +88,7 @@ func (r *Router) listBots(w http.ResponseWriter, req *http.Request) {
|
|||
|
||||
func (r *Router) listCoOwners(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
bot := ctx.Value(botKey).(*db.Bot)
|
||||
bot := middleware.GetBot(ctx)
|
||||
|
||||
owners, err := r.bots.ListCoOwners(ctx, bot.ID)
|
||||
if err != nil {
|
||||
|
|
@ -108,7 +101,7 @@ func (r *Router) listCoOwners(w http.ResponseWriter, req *http.Request) {
|
|||
|
||||
func (r *Router) addCoOwner(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
bot := ctx.Value(botKey).(*db.Bot)
|
||||
bot := middleware.GetBot(ctx)
|
||||
userID := chi.URLParam(req, "userID")
|
||||
|
||||
if err := r.bots.AddCoOwner(ctx, bot.ID, userID); err != nil {
|
||||
|
|
@ -128,7 +121,7 @@ func (r *Router) addCoOwner(w http.ResponseWriter, req *http.Request) {
|
|||
|
||||
func (r *Router) removeCoOwner(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
bot := ctx.Value(botKey).(*db.Bot)
|
||||
bot := middleware.GetBot(ctx)
|
||||
userID := chi.URLParam(req, "userID")
|
||||
|
||||
if err := r.bots.RemoveCoOwner(ctx, bot.ID, userID); err != nil {
|
||||
|
|
@ -146,21 +139,6 @@ func (r *Router) removeCoOwner(w http.ResponseWriter, req *http.Request) {
|
|||
render.NoContent(w, req)
|
||||
}
|
||||
|
||||
func (r *Router) BotContext(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
botID := chi.URLParam(req, "botID")
|
||||
bot, err := r.bots.Get(ctx, botID)
|
||||
if err != nil {
|
||||
render.Render(w, req, errorutil.ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, botKey, bot)
|
||||
next.ServeHTTP(w, req.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Router) BotCache(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.Header().Add("Cache-Control", "max-age=3600")
|
||||
|
|
|
|||
45
internal/middleware/bot.go
Normal file
45
internal/middleware/bot.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"codeberg.org/nextgo/dbots/internal/db"
|
||||
"codeberg.org/nextgo/dbots/internal/errorutil"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
const botKey contextKey = "bot"
|
||||
|
||||
func BotContext(q *db.Queries) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
botID := chi.URLParam(req, "botID")
|
||||
bot, err := q.GetBot(ctx, botID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
render.Render(w, req, errorutil.ErrNotFound)
|
||||
} else {
|
||||
render.Render(w, req, errorutil.ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, botKey, bot)
|
||||
next.ServeHTTP(w, req.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetBot(ctx context.Context) *db.Bot {
|
||||
bot, ok := ctx.Value(botKey).(*db.Bot)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return bot
|
||||
}
|
||||
|
|
@ -24,9 +24,7 @@ type Server struct {
|
|||
func NewServer(queries *db.Queries) *Server {
|
||||
router := chi.NewMux()
|
||||
|
||||
router.Use(httplog.RequestLogger(slog.Default(), &httplog.Options{
|
||||
RecoverPanics: true,
|
||||
}))
|
||||
router.Use(httplog.RequestLogger(slog.Default(), &httplog.Options{}))
|
||||
router.Use(middleware.Recoverer)
|
||||
router.Use(middleware.RequestID)
|
||||
router.Use(middleware.RealIP)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue