package main import ( "bytes" "context" "crypto/rand" "database/sql" "encoding/hex" "encoding/json" "fmt" "log" "math/big" "net/http" "net/url" "os" "strings" "time" "github.com/ryanchen/bbq/db" ) type contextKey string const userContextKey contextKey = "user" func (s *Server) currentUser(r *http.Request) *db.User { if !s.features.Auth { return nil } u, _ := r.Context().Value(userContextKey).(*db.User) return u } // sessionMiddleware loads the user from the session cookie into context. func (s *Server) sessionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !s.features.Auth { next.ServeHTTP(w, r) return } cookie, err := r.Cookie("session") if err != nil { next.ServeHTTP(w, r) return } sess, err := s.q.GetSession(r.Context(), cookie.Value) if err != nil { next.ServeHTTP(w, r) return } var u db.User row := s.db.QueryRowContext(r.Context(), "SELECT id, phone, email, name, created_at FROM users WHERE id = ?", sess.UserID) if err := row.Scan(&u.ID, &u.Phone, &u.Email, &u.Name, &u.CreatedAt); err != nil { next.ServeHTTP(w, r) return } ctx := context.WithValue(r.Context(), userContextKey, &u) next.ServeHTTP(w, r.WithContext(ctx)) }) } // requireAuth redirects to /login if not logged in. func (s *Server) requireAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if s.currentUser(r) == nil { http.Redirect(w, r, "/login", http.StatusSeeOther) return } next.ServeHTTP(w, r) }) } // isEmail returns true if the input looks like an email address. func isEmail(s string) bool { return strings.Contains(s, "@") } func (s *Server) handleLoginPage(w http.ResponseWriter, r *http.Request) { if s.currentUser(r) != nil { http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{ "Step": "identify", "AuthEnabled": true, }) } func (s *Server) handleLoginSubmit(w http.ResponseWriter, r *http.Request) { raw := strings.TrimSpace(r.FormValue("identifier")) if raw == "" { http.Redirect(w, r, "/login", http.StatusSeeOther) return } var identifier string var method string // "email" or "phone" if isEmail(raw) { identifier = strings.ToLower(raw) method = "email" } else { identifier = normalizePhone(raw) method = "phone" if identifier == "" { http.Redirect(w, r, "/login", http.StatusSeeOther) return } } code := generateCode() expiresAt := time.Now().UTC().Add(10 * time.Minute) s.q.CreateVerificationCode(r.Context(), db.CreateVerificationCodeParams{ Identifier: identifier, Code: code, ExpiresAt: expiresAt, }) if method == "email" { if err := sendVerificationEmail(identifier, code); err != nil { log.Printf("failed to send verification email to %s: %v", identifier, err) } } else { if err := sendVerificationSMS(identifier, code); err != nil { log.Printf("failed to send verification SMS to %s: %v", identifier, err) } } pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{ "Step": "code", "Identifier": identifier, "Method": method, "AuthEnabled": true, }) } func (s *Server) handleVerifyCode(w http.ResponseWriter, r *http.Request) { identifier := strings.TrimSpace(r.FormValue("identifier")) method := r.FormValue("method") code := strings.TrimSpace(r.FormValue("code")) vc, err := s.q.GetVerificationCode(r.Context(), db.GetVerificationCodeParams{ Identifier: identifier, Code: code, }) if err != nil { pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{ "Step": "code", "Identifier": identifier, "Method": method, "Error": "Invalid or expired code. Try again.", "AuthEnabled": true, }) return } s.q.MarkVerificationCodeUsed(r.Context(), vc.ID) // Get or create user var user db.User if method == "email" { user, err = s.q.GetUserByEmail(r.Context(), sql.NullString{String: identifier, Valid: true}) if err == sql.ErrNoRows { user, err = s.q.CreateUserByEmail(r.Context(), sql.NullString{String: identifier, Valid: true}) } } else { user, err = s.q.GetUserByPhone(r.Context(), sql.NullString{String: identifier, Valid: true}) if err == sql.ErrNoRows { user, err = s.q.CreateUserByPhone(r.Context(), sql.NullString{String: identifier, Valid: true}) } } if err != nil { log.Printf("user lookup/create: %v", err) http.Error(w, "Internal error", 500) return } // Create session token := generateSessionToken() expiresAt := time.Now().UTC().Add(30 * 24 * time.Hour) // 30 days s.q.CreateSession(r.Context(), db.CreateSessionParams{ Token: token, UserID: user.ID, ExpiresAt: expiresAt, }) http.SetCookie(w, &http.Cookie{ Name: "session", Value: token, Path: "/", Expires: expiresAt, HttpOnly: true, SameSite: http.SameSiteLaxMode, }) // If user has no name, send them to set it if user.Name == "" { http.Redirect(w, r, "/account/name", http.StatusSeeOther) return } http.Redirect(w, r, "/dashboard", http.StatusSeeOther) } func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session") if err == nil { s.q.DeleteSession(r.Context(), cookie.Value) } http.SetCookie(w, &http.Cookie{ Name: "session", Value: "", Path: "/", MaxAge: -1, HttpOnly: true, }) http.Redirect(w, r, "/", http.StatusSeeOther) } func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) { user := s.currentUser(r) events, err := s.q.ListEventsByUser(r.Context(), sql.NullInt64{Int64: user.ID, Valid: true}) if err != nil { log.Printf("list events: %v", err) http.Error(w, "Internal error", 500) return } pageTmpl["dashboard"].ExecuteTemplate(w, "layout", map[string]any{ "User": user, "Events": events, "AuthEnabled": true, }) } func (s *Server) handleNamePage(w http.ResponseWriter, r *http.Request) { user := s.currentUser(r) pageTmpl["name"].ExecuteTemplate(w, "layout", map[string]any{ "User": user, "AuthEnabled": true, }) } func (s *Server) handleNameSubmit(w http.ResponseWriter, r *http.Request) { user := s.currentUser(r) name := strings.TrimSpace(r.FormValue("name")) if name == "" { http.Redirect(w, r, "/account/name", http.StatusSeeOther) return } s.q.UpdateUserName(r.Context(), db.UpdateUserNameParams{ Name: name, ID: user.ID, }) http.Redirect(w, r, "/dashboard", http.StatusSeeOther) } func generateCode() string { n, _ := rand.Int(rand.Reader, big.NewInt(1000000)) return fmt.Sprintf("%06d", n.Int64()) } func generateSessionToken() string { b := make([]byte, 32) rand.Read(b) return hex.EncodeToString(b) } // normalizePhone strips non-digit chars and prepends +1 if no country code. func normalizePhone(raw string) string { var digits strings.Builder for _, c := range strings.TrimSpace(raw) { if c >= '0' && c <= '9' { digits.WriteRune(c) } } d := digits.String() if d == "" { return "" } // If 10 digits, assume US and prepend 1 if len(d) == 10 { d = "1" + d } return "+" + d } func sendVerificationSMS(to, code string) error { sid := os.Getenv("TWILIO_ACCOUNT_SID") token := os.Getenv("TWILIO_AUTH_TOKEN") from := os.Getenv("TWILIO_FROM_NUMBER") if sid == "" || token == "" || from == "" { log.Printf("Twilio not configured — code for %s is: %s", to, code) return nil } body := fmt.Sprintf("Your bbq login code: %s", code) apiURL := fmt.Sprintf("https://api.twilio.com/2010-04-01/Accounts/%s/Messages.json", sid) data := url.Values{} data.Set("To", to) data.Set("From", from) data.Set("Body", body) req, _ := http.NewRequest("POST", apiURL, strings.NewReader(data.Encode())) req.SetBasicAuth(sid, token) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode >= 400 { return fmt.Errorf("twilio API returned %d", resp.StatusCode) } return nil } func sendVerificationEmail(to, code string) error { apiKey := os.Getenv("RESEND_API_KEY") if apiKey == "" { log.Printf("RESEND_API_KEY not set — code for %s is: %s", to, code) return nil } fromAddr := os.Getenv("BBQ_FROM_EMAIL") if fromAddr == "" { fromAddr = "bbq " } payload := map[string]any{ "from": fromAddr, "to": []string{to}, "subject": fmt.Sprintf("Your login code: %s", code), "html": fmt.Sprintf(`

Your login code

%s

This code expires in 10 minutes.

`, code), } body, _ := json.Marshal(payload) req, _ := http.NewRequest("POST", "https://api.resend.com/emails", bytes.NewReader(body)) req.Header.Set("Authorization", "Bearer "+apiKey) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode >= 400 { return fmt.Errorf("resend API returned %d", resp.StatusCode) } return nil }