a0c4b28d1e
Auto-detect whether the user entered an email or phone number. Email sends via Resend, phone sends via Twilio SMS. Users table has nullable phone and email columns; verification_codes uses a generic identifier field. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
354 lines
9.2 KiB
Go
354 lines
9.2 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"math/big"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ryanchen/bbq/db"
|
|
)
|
|
|
|
// contextKey is an unexported type for request-context keys. Because no other
// package can name this type, keys defined here cannot collide with theirs.
type contextKey string

// userContextKey is the context key under which sessionMiddleware stores the
// logged-in *db.User and from which currentUser reads it.
const userContextKey = contextKey("user")
|
|
|
|
func (s *Server) currentUser(r *http.Request) *db.User {
|
|
if !s.features.Auth {
|
|
return nil
|
|
}
|
|
u, _ := r.Context().Value(userContextKey).(*db.User)
|
|
return u
|
|
}
|
|
|
|
// sessionMiddleware loads the user from the session cookie into context.
|
|
func (s *Server) sessionMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !s.features.Auth {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
cookie, err := r.Cookie("session")
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
sess, err := s.q.GetSession(r.Context(), cookie.Value)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
var u db.User
|
|
row := s.db.QueryRowContext(r.Context(),
|
|
"SELECT id, phone, email, name, created_at FROM users WHERE id = ?", sess.UserID)
|
|
if err := row.Scan(&u.ID, &u.Phone, &u.Email, &u.Name, &u.CreatedAt); err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
ctx := context.WithValue(r.Context(), userContextKey, &u)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// requireAuth redirects to /login if not logged in.
|
|
func (s *Server) requireAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if s.currentUser(r) == nil {
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// isEmail reports whether s should be treated as an email address rather than
// a phone number. It is a routing heuristic, not a validator: any input that
// contains an "@" qualifies.
func isEmail(s string) bool {
	return strings.IndexByte(s, '@') != -1
}
|
|
|
|
func (s *Server) handleLoginPage(w http.ResponseWriter, r *http.Request) {
|
|
if s.currentUser(r) != nil {
|
|
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
|
return
|
|
}
|
|
pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{
|
|
"Step": "identify",
|
|
"AuthEnabled": true,
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleLoginSubmit(w http.ResponseWriter, r *http.Request) {
|
|
raw := strings.TrimSpace(r.FormValue("identifier"))
|
|
if raw == "" {
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
var identifier string
|
|
var method string // "email" or "phone"
|
|
if isEmail(raw) {
|
|
identifier = strings.ToLower(raw)
|
|
method = "email"
|
|
} else {
|
|
identifier = normalizePhone(raw)
|
|
method = "phone"
|
|
if identifier == "" {
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
return
|
|
}
|
|
}
|
|
|
|
code := generateCode()
|
|
expiresAt := time.Now().Add(10 * time.Minute)
|
|
s.q.CreateVerificationCode(r.Context(), db.CreateVerificationCodeParams{
|
|
Identifier: identifier,
|
|
Code: code,
|
|
ExpiresAt: expiresAt,
|
|
})
|
|
|
|
if method == "email" {
|
|
if err := sendVerificationEmail(identifier, code); err != nil {
|
|
log.Printf("failed to send verification email to %s: %v", identifier, err)
|
|
}
|
|
} else {
|
|
if err := sendVerificationSMS(identifier, code); err != nil {
|
|
log.Printf("failed to send verification SMS to %s: %v", identifier, err)
|
|
}
|
|
}
|
|
|
|
pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{
|
|
"Step": "code",
|
|
"Identifier": identifier,
|
|
"Method": method,
|
|
"AuthEnabled": true,
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleVerifyCode(w http.ResponseWriter, r *http.Request) {
|
|
identifier := strings.TrimSpace(r.FormValue("identifier"))
|
|
method := r.FormValue("method")
|
|
code := strings.TrimSpace(r.FormValue("code"))
|
|
|
|
vc, err := s.q.GetVerificationCode(r.Context(), db.GetVerificationCodeParams{
|
|
Identifier: identifier,
|
|
Code: code,
|
|
})
|
|
if err != nil {
|
|
pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{
|
|
"Step": "code",
|
|
"Identifier": identifier,
|
|
"Method": method,
|
|
"Error": "Invalid or expired code. Try again.",
|
|
"AuthEnabled": true,
|
|
})
|
|
return
|
|
}
|
|
|
|
s.q.MarkVerificationCodeUsed(r.Context(), vc.ID)
|
|
|
|
// Get or create user
|
|
var user db.User
|
|
if method == "email" {
|
|
user, err = s.q.GetUserByEmail(r.Context(), sql.NullString{String: identifier, Valid: true})
|
|
if err == sql.ErrNoRows {
|
|
user, err = s.q.CreateUserByEmail(r.Context(), sql.NullString{String: identifier, Valid: true})
|
|
}
|
|
} else {
|
|
user, err = s.q.GetUserByPhone(r.Context(), sql.NullString{String: identifier, Valid: true})
|
|
if err == sql.ErrNoRows {
|
|
user, err = s.q.CreateUserByPhone(r.Context(), sql.NullString{String: identifier, Valid: true})
|
|
}
|
|
}
|
|
if err != nil {
|
|
log.Printf("user lookup/create: %v", err)
|
|
http.Error(w, "Internal error", 500)
|
|
return
|
|
}
|
|
|
|
// Create session
|
|
token := generateSessionToken()
|
|
expiresAt := time.Now().Add(30 * 24 * time.Hour) // 30 days
|
|
s.q.CreateSession(r.Context(), db.CreateSessionParams{
|
|
Token: token,
|
|
UserID: user.ID,
|
|
ExpiresAt: expiresAt,
|
|
})
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session",
|
|
Value: token,
|
|
Path: "/",
|
|
Expires: expiresAt,
|
|
HttpOnly: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
// If user has no name, send them to set it
|
|
if user.Name == "" {
|
|
http.Redirect(w, r, "/account/name", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
|
}
|
|
|
|
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
|
cookie, err := r.Cookie("session")
|
|
if err == nil {
|
|
s.q.DeleteSession(r.Context(), cookie.Value)
|
|
}
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "session",
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
HttpOnly: true,
|
|
})
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
}
|
|
|
|
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
|
user := s.currentUser(r)
|
|
events, err := s.q.ListEventsByUser(r.Context(), sql.NullInt64{Int64: user.ID, Valid: true})
|
|
if err != nil {
|
|
log.Printf("list events: %v", err)
|
|
http.Error(w, "Internal error", 500)
|
|
return
|
|
}
|
|
pageTmpl["dashboard"].ExecuteTemplate(w, "layout", map[string]any{
|
|
"User": user,
|
|
"Events": events,
|
|
"AuthEnabled": true,
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleNamePage(w http.ResponseWriter, r *http.Request) {
|
|
user := s.currentUser(r)
|
|
pageTmpl["name"].ExecuteTemplate(w, "layout", map[string]any{
|
|
"User": user,
|
|
"AuthEnabled": true,
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleNameSubmit(w http.ResponseWriter, r *http.Request) {
|
|
user := s.currentUser(r)
|
|
name := strings.TrimSpace(r.FormValue("name"))
|
|
if name == "" {
|
|
http.Redirect(w, r, "/account/name", http.StatusSeeOther)
|
|
return
|
|
}
|
|
s.q.UpdateUserName(r.Context(), db.UpdateUserNameParams{
|
|
Name: name,
|
|
ID: user.ID,
|
|
})
|
|
http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
|
|
}
|
|
|
|
// generateCode returns a six-digit, zero-padded decimal string ("000000" to
// "999999") drawn uniformly from crypto/rand, for use as a one-time login code.
//
// It panics with a descriptive message if crypto/rand fails. The previous code
// ignored the error and crashed on a nil *big.Int instead. A login code must
// never come from a weaker source, so failing loudly is the only safe option.
func generateCode() string {
	n, err := rand.Int(rand.Reader, big.NewInt(1000000))
	if err != nil {
		panic(fmt.Sprintf("generateCode: crypto/rand failed: %v", err))
	}
	return fmt.Sprintf("%06d", n.Int64())
}
|
|
|
|
// generateSessionToken returns a new session token: 32 bytes from
// crypto/rand, hex-encoded to a 64-character lowercase string.
//
// It panics if crypto/rand fails. Ignoring that error would leave the buffer
// zero-filled (or partly filled), producing a guessable token that would grant
// access to whatever account the session is bound to.
func generateSessionToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("generateSessionToken: crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(b)
}
|
|
|
|
// normalizePhone reduces raw to "+" followed by its ASCII digits. Every other
// character is discarded, including spaces, dashes, parentheses and a leading
// "+".
//
// Exactly 10 digits are assumed to be a US number without a country code and
// get a leading "1". Other lengths are kept as entered.
//
// It returns "" when raw contains no digits, or more than 15 digits (the
// ITU-T E.164 maximum for an international number). Callers treat "" as
// invalid input.
func normalizePhone(raw string) string {
	const maxE164Digits = 15

	var digits strings.Builder
	for _, c := range raw {
		// ASCII only: other Unicode digits are not valid in a dial string.
		if c >= '0' && c <= '9' {
			digits.WriteRune(c)
		}
	}
	d := digits.String()

	// If 10 digits, assume US and prepend 1
	if len(d) == 10 {
		d = "1" + d
	}
	if d == "" || len(d) > maxE164Digits {
		return ""
	}
	return "+" + d
}
|
|
|
|
// sendVerificationSMS texts the login code to the phone number `to` through
// the Twilio Messages API.
//
// Configuration comes from TWILIO_ACCOUNT_SID, TWILIO_AUTH_TOKEN and
// TWILIO_FROM_NUMBER. If any of them is empty, the code is written to the
// server log instead and nil is returned.
//
// It returns an error in these cases:
//   - the request cannot be built;
//   - the request fails or exceeds the 10-second client timeout;
//   - Twilio answers with a 4xx/5xx status.
//
// For an error status, the body's "message" field is included when the body
// is JSON that carries one.
func sendVerificationSMS(to, code string) error {
	sid := os.Getenv("TWILIO_ACCOUNT_SID")
	token := os.Getenv("TWILIO_AUTH_TOKEN")
	from := os.Getenv("TWILIO_FROM_NUMBER")

	if sid == "" || token == "" || from == "" {
		log.Printf("Twilio not configured — code for %s is: %s", to, code)
		return nil
	}

	body := fmt.Sprintf("Your bbq login code: %s", code)
	apiURL := fmt.Sprintf("https://api.twilio.com/2010-04-01/Accounts/%s/Messages.json", sid)

	data := url.Values{}
	data.Set("To", to)
	data.Set("From", from)
	data.Set("Body", body)

	req, err := http.NewRequest(http.MethodPost, apiURL, strings.NewReader(data.Encode()))
	if err != nil {
		return fmt.Errorf("building twilio request: %w", err)
	}
	req.SetBasicAuth(sid, token)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// http.DefaultClient has no timeout, so a stalled connection would block
	// the calling request handler indefinitely. A nil Transport still shares
	// http.DefaultTransport's connection pool.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		var apiErr struct {
			Message string `json:"message"`
		}
		// Best effort: a non-JSON body simply leaves Message empty.
		_ = json.NewDecoder(resp.Body).Decode(&apiErr)
		if apiErr.Message != "" {
			return fmt.Errorf("twilio API returned %d: %s", resp.StatusCode, apiErr.Message)
		}
		return fmt.Errorf("twilio API returned %d", resp.StatusCode)
	}
	return nil
}
|
|
|
|
// sendVerificationEmail emails the login code to the address `to` through the
// Resend HTTP API (POST https://api.resend.com/emails).
//
// Configuration comes from two environment variables:
//   - RESEND_API_KEY (required). When unset, the code is written to the server
//     log instead and nil is returned.
//   - BBQ_FROM_EMAIL (optional). It sets the sender and defaults to
//     "bbq <noreply@bbq.torrtle.co>".
//
// It returns an error in these cases:
//   - the payload or request cannot be built;
//   - the request fails or exceeds the 10-second client timeout;
//   - Resend answers with a 4xx/5xx status.
//
// For an error status, the body's "message" field is included when the body
// is JSON that carries one.
func sendVerificationEmail(to, code string) error {
	apiKey := os.Getenv("RESEND_API_KEY")
	if apiKey == "" {
		log.Printf("RESEND_API_KEY not set — code for %s is: %s", to, code)
		return nil
	}

	fromAddr := os.Getenv("BBQ_FROM_EMAIL")
	if fromAddr == "" {
		fromAddr = "bbq <noreply@bbq.torrtle.co>"
	}

	// NOTE(review): code is interpolated into the HTML without escaping. This is
	// presumably safe because callers pass generateCode's digit string. Escape
	// it if other values can ever reach here.
	payload := map[string]any{
		"from":    fromAddr,
		"to":      []string{to},
		"subject": fmt.Sprintf("Your login code: %s", code),
		"html":    fmt.Sprintf(`<div style="font-family:sans-serif;max-width:400px;margin:0 auto;padding:40px 20px;"><h2 style="margin:0 0 16px;">Your login code</h2><div style="font-size:32px;font-weight:bold;letter-spacing:8px;background:#f5f0e8;border:2px solid #1a1a1a;padding:20px;text-align:center;margin:16px 0;">%s</div><p style="color:#555;font-size:14px;">This code expires in 10 minutes.</p></div>`, code),
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("encoding resend payload: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, "https://api.resend.com/emails", bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("building resend request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")

	// http.DefaultClient has no timeout, so a stalled connection would block
	// the calling request handler indefinitely. A nil Transport still shares
	// http.DefaultTransport's connection pool.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		var apiErr struct {
			Message string `json:"message"`
		}
		// Best effort: a non-JSON body simply leaves Message empty.
		_ = json.NewDecoder(resp.Body).Decode(&apiErr)
		if apiErr.Message != "" {
			return fmt.Errorf("resend API returned %d: %s", resp.StatusCode, apiErr.Message)
		}
		return fmt.Errorf("resend API returned %d", resp.StatusCode)
	}
	return nil
}
|