Harden inputs: sanitize lengths, request size limits, caps on slots/rsvps

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-15 10:28:06 -04:00
parent d51e7fe867
commit ef3aa3e9c3
+66 -19
View File
@@ -15,6 +15,23 @@ import (
"github.com/ryanchen/bbq/db" "github.com/ryanchen/bbq/db"
) )
const (
maxFieldLen = 200
maxNoteLen = 500
maxSlots = 20
maxRsvps = 200
maxClaims = 50
maxMaxClaims = 50
)
func sanitize(s string, maxLen int) string {
s = strings.TrimSpace(s)
if len(s) > maxLen {
s = s[:maxLen]
}
return s
}
func randomToken() string { func randomToken() string {
b := make([]byte, 16) b := make([]byte, 16)
rand.Read(b) rand.Read(b)
@@ -34,11 +51,12 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) handleCreateEvent(w http.ResponseWriter, r *http.Request) { func (s *Server) handleCreateEvent(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, 32*1024)
r.ParseForm() r.ParseForm()
title := strings.TrimSpace(r.FormValue("title")) title := sanitize(r.FormValue("title"), maxFieldLen)
date := strings.TrimSpace(r.FormValue("date")) date := sanitize(r.FormValue("date"), maxFieldLen)
time_ := strings.TrimSpace(r.FormValue("time")) time_ := sanitize(r.FormValue("time"), maxFieldLen)
location := strings.TrimSpace(r.FormValue("location")) location := sanitize(r.FormValue("location"), maxFieldLen)
if title == "" { if title == "" {
http.Error(w, "Title is required", http.StatusBadRequest) http.Error(w, "Title is required", http.StatusBadRequest)
@@ -56,28 +74,35 @@ func (s *Server) handleCreateEvent(w http.ResponseWriter, r *http.Request) {
return return
} }
// Parse slot fields: slots like "drinks", "salad", etc.
slotNames := r.Form["slot_name"] slotNames := r.Form["slot_name"]
slotEmojis := r.Form["slot_emoji"] slotEmojis := r.Form["slot_emoji"]
slotMaxes := r.Form["slot_max"] slotMaxes := r.Form["slot_max"]
created := 0
for i, name := range slotNames { for i, name := range slotNames {
name = strings.TrimSpace(name) if created >= maxSlots {
break
}
name = sanitize(name, maxFieldLen)
if name == "" { if name == "" {
continue continue
} }
emoji := "" emoji := ""
if i < len(slotEmojis) { if i < len(slotEmojis) {
emoji = strings.TrimSpace(slotEmojis[i]) emoji = sanitize(slotEmojis[i], 32)
} }
maxClaims := int64(1) mc := int64(1)
if i < len(slotMaxes) { if i < len(slotMaxes) {
if v, err := strconv.ParseInt(slotMaxes[i], 10, 64); err == nil && v > 0 { if v, err := strconv.ParseInt(slotMaxes[i], 10, 64); err == nil && v > 0 {
maxClaims = v mc = v
} }
} }
if mc > maxMaxClaims {
mc = maxMaxClaims
}
s.q.CreateSlot(r.Context(), db.CreateSlotParams{ s.q.CreateSlot(r.Context(), db.CreateSlotParams{
EventID: event.ID, Name: name, Emoji: emoji, MaxClaims: maxClaims, SortOrder: int64(i), EventID: event.ID, Name: name, Emoji: emoji, MaxClaims: mc, SortOrder: int64(i),
}) })
created++
} }
http.Redirect(w, r, fmt.Sprintf("/e/%s/admin/%s", event.Slug, event.AdminToken), http.StatusSeeOther) http.Redirect(w, r, fmt.Sprintf("/e/%s/admin/%s", event.Slug, event.AdminToken), http.StatusSeeOther)
@@ -188,6 +213,7 @@ func (s *Server) handleSlotsPartial(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleClaim(w http.ResponseWriter, r *http.Request) { func (s *Server) handleClaim(w http.ResponseWriter, r *http.Request) {
slug := chi.URLParam(r, "slug") slug := chi.URLParam(r, "slug")
r.Body = http.MaxBytesReader(w, r.Body, 8*1024)
r.ParseForm() r.ParseForm()
slotID, err := strconv.ParseInt(r.FormValue("slot_id"), 10, 64) slotID, err := strconv.ParseInt(r.FormValue("slot_id"), 10, 64)
@@ -195,12 +221,12 @@ func (s *Server) handleClaim(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Invalid slot", http.StatusBadRequest) http.Error(w, "Invalid slot", http.StatusBadRequest)
return return
} }
name := strings.TrimSpace(r.FormValue("name")) name := sanitize(r.FormValue("name"), maxFieldLen)
if name == "" { if name == "" {
http.Error(w, "Name is required", http.StatusBadRequest) http.Error(w, "Name is required", http.StatusBadRequest)
return return
} }
note := strings.TrimSpace(r.FormValue("note")) note := sanitize(r.FormValue("note"), maxNoteLen)
// Check slot exists and belongs to this event // Check slot exists and belongs to this event
slot, err := s.q.GetSlot(r.Context(), slotID) slot, err := s.q.GetSlot(r.Context(), slotID)
@@ -274,14 +300,15 @@ func (s *Server) handleUnclaim(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleRsvp(w http.ResponseWriter, r *http.Request) { func (s *Server) handleRsvp(w http.ResponseWriter, r *http.Request) {
slug := chi.URLParam(r, "slug") slug := chi.URLParam(r, "slug")
r.Body = http.MaxBytesReader(w, r.Body, 8*1024)
r.ParseForm() r.ParseForm()
name := strings.TrimSpace(r.FormValue("name")) name := sanitize(r.FormValue("name"), maxFieldLen)
if name == "" { if name == "" {
http.Error(w, "Name is required", http.StatusBadRequest) http.Error(w, "Name is required", http.StatusBadRequest)
return return
} }
note := strings.TrimSpace(r.FormValue("note")) note := sanitize(r.FormValue("note"), maxNoteLen)
event, err := s.q.GetEventBySlug(r.Context(), slug) event, err := s.q.GetEventBySlug(r.Context(), slug)
if err != nil { if err != nil {
@@ -289,6 +316,16 @@ func (s *Server) handleRsvp(w http.ResponseWriter, r *http.Request) {
return return
} }
count, err := s.q.CountRsvps(r.Context(), event.ID)
if err != nil {
http.Error(w, "error", http.StatusInternalServerError)
return
}
if count >= maxRsvps {
http.Error(w, "RSVP list is full", http.StatusConflict)
return
}
_, err = s.q.CreateRsvp(r.Context(), db.CreateRsvpParams{ _, err = s.q.CreateRsvp(r.Context(), db.CreateRsvpParams{
EventID: event.ID, Name: name, Note: note, EventID: event.ID, Name: name, Note: note,
}) })
@@ -396,12 +433,16 @@ func (s *Server) handleCreateSlot(w http.ResponseWriter, r *http.Request) {
return return
} }
r.Body = http.MaxBytesReader(w, r.Body, 8*1024)
r.ParseForm() r.ParseForm()
name := strings.TrimSpace(r.FormValue("name")) name := sanitize(r.FormValue("name"), maxFieldLen)
emoji := strings.TrimSpace(r.FormValue("emoji")) emoji := sanitize(r.FormValue("emoji"), 32)
maxClaims := int64(1) mc := int64(1)
if v, err := strconv.ParseInt(r.FormValue("max_claims"), 10, 64); err == nil && v > 0 { if v, err := strconv.ParseInt(r.FormValue("max_claims"), 10, 64); err == nil && v > 0 {
maxClaims = v mc = v
}
if mc > maxMaxClaims {
mc = maxMaxClaims
} }
if name == "" { if name == "" {
@@ -409,8 +450,14 @@ func (s *Server) handleCreateSlot(w http.ResponseWriter, r *http.Request) {
return return
} }
slots, _ := s.q.ListSlots(r.Context(), event.ID)
if len(slots) >= maxSlots {
http.Error(w, "Too many slots", http.StatusConflict)
return
}
_, err = s.q.CreateSlot(r.Context(), db.CreateSlotParams{ _, err = s.q.CreateSlot(r.Context(), db.CreateSlotParams{
EventID: event.ID, Name: name, Emoji: emoji, MaxClaims: maxClaims, SortOrder: 999, EventID: event.ID, Name: name, Emoji: emoji, MaxClaims: mc, SortOrder: 999,
}) })
if err != nil { if err != nil {
http.Error(w, "Failed", http.StatusInternalServerError) http.Error(w, "Failed", http.StatusInternalServerError)