diff --git a/auth.go b/auth.go index 4285310..a06c99c 100644 --- a/auth.go +++ b/auth.go @@ -1,10 +1,12 @@ package main import ( + "bytes" "context" "crypto/rand" "database/sql" "encoding/hex" + "encoding/json" "fmt" "log" "math/big" @@ -48,8 +50,8 @@ func (s *Server) sessionMiddleware(next http.Handler) http.Handler { } var u db.User row := s.db.QueryRowContext(r.Context(), - "SELECT id, phone, name, created_at FROM users WHERE id = ?", sess.UserID) - if err := row.Scan(&u.ID, &u.Phone, &u.Name, &u.CreatedAt); err != nil { + "SELECT id, phone, email, name, created_at FROM users WHERE id = ?", sess.UserID) + if err := row.Scan(&u.ID, &u.Phone, &u.Email, &u.Name, &u.CreatedAt); err != nil { next.ServeHTTP(w, r) return } @@ -69,55 +71,83 @@ func (s *Server) requireAuth(next http.Handler) http.Handler { }) } +// isEmail returns true if the input looks like an email address. +func isEmail(s string) bool { + return strings.Contains(s, "@") +} + func (s *Server) handleLoginPage(w http.ResponseWriter, r *http.Request) { if s.currentUser(r) != nil { http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{ - "Step": "phone", + "Step": "identify", "AuthEnabled": true, }) } func (s *Server) handleLoginSubmit(w http.ResponseWriter, r *http.Request) { - phone := normalizePhone(r.FormValue("phone")) - if phone == "" { + raw := strings.TrimSpace(r.FormValue("identifier")) + if raw == "" { http.Redirect(w, r, "/login", http.StatusSeeOther) return } + var identifier string + var method string // "email" or "phone" + if isEmail(raw) { + identifier = strings.ToLower(raw) + method = "email" + } else { + identifier = normalizePhone(raw) + method = "phone" + if identifier == "" { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + } + code := generateCode() expiresAt := time.Now().Add(10 * time.Minute) s.q.CreateVerificationCode(r.Context(), db.CreateVerificationCodeParams{ - Phone: phone, - Code: code, - ExpiresAt: expiresAt, + Identifier: identifier, + Code: code, + ExpiresAt: expiresAt, }) - if err := sendVerificationSMS(phone, code); err != nil { - log.Printf("failed to send verification SMS to %s: %v", phone, err) + if method == "email" { + if err := sendVerificationEmail(identifier, code); err != nil { + log.Printf("failed to send verification email to %s: %v", identifier, err) + } + } else { + if err := sendVerificationSMS(identifier, code); err != nil { + log.Printf("failed to send verification SMS to %s: %v", identifier, err) + } } pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{ "Step": "code", - "Phone": phone, + "Identifier": identifier, + "Method": method, "AuthEnabled": true, }) } func (s *Server) handleVerifyCode(w http.ResponseWriter, r *http.Request) { - phone := normalizePhone(r.FormValue("phone")) + identifier := strings.TrimSpace(r.FormValue("identifier")) + method := r.FormValue("method") code := strings.TrimSpace(r.FormValue("code")) vc, err := s.q.GetVerificationCode(r.Context(), db.GetVerificationCodeParams{ - Phone: phone, - Code: code, + Identifier: identifier, + Code: code, }) if err != nil { pageTmpl["login"].ExecuteTemplate(w, "layout", map[string]any{ "Step": "code", - "Phone": phone, + "Identifier": identifier, + "Method": method, "Error": "Invalid or expired code. Try again.", "AuthEnabled": true, }) @@ -127,9 +157,17 @@ func (s *Server) handleVerifyCode(w http.ResponseWriter, r *http.Request) { s.q.MarkVerificationCodeUsed(r.Context(), vc.ID) // Get or create user - user, err := s.q.GetUserByPhone(r.Context(), phone) - if err == sql.ErrNoRows { - user, err = s.q.CreateUser(r.Context(), phone) + var user db.User + if method == "email" { + user, err = s.q.GetUserByEmail(r.Context(), sql.NullString{String: identifier, Valid: true}) + if err == sql.ErrNoRows { + user, err = s.q.CreateUserByEmail(r.Context(), sql.NullString{String: identifier, Valid: true}) + } + } else { + user, err = s.q.GetUserByPhone(r.Context(), sql.NullString{String: identifier, Valid: true}) + if err == sql.ErrNoRows { + user, err = s.q.CreateUserByPhone(r.Context(), sql.NullString{String: identifier, Valid: true}) + } } if err != nil { log.Printf("user lookup/create: %v", err) @@ -278,3 +316,38 @@ func sendVerificationSMS(to, code string) error { } return nil } + +func sendVerificationEmail(to, code string) error { + apiKey := os.Getenv("RESEND_API_KEY") + if apiKey == "" { + log.Printf("RESEND_API_KEY not set — code for %s is: %s", to, code) + return nil + } + + fromAddr := os.Getenv("BBQ_FROM_EMAIL") + if fromAddr == "" { + fromAddr = "bbq " + } + + payload := map[string]any{ + "from": fromAddr, + "to": []string{to}, + "subject": fmt.Sprintf("Your login code: %s", code), + "html": fmt.Sprintf(`

Your login code

%s

This code expires in 10 minutes.

`, code), + } + + body, _ := json.Marshal(payload) + req, _ := http.NewRequest("POST", "https://api.resend.com/emails", bytes.NewReader(body)) + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return fmt.Errorf("resend API returned %d", resp.StatusCode) + } + return nil +} diff --git a/db/models.go b/db/models.go index 0ec6ab4..4fe1917 100644 --- a/db/models.go +++ b/db/models.go @@ -55,15 +55,16 @@ type Slot struct { type User struct { ID int64 - Phone string + Phone sql.NullString + Email sql.NullString Name string CreatedAt time.Time } type VerificationCode struct { - ID int64 - Phone string - Code string - ExpiresAt time.Time - Used int64 + ID int64 + Identifier string + Code string + ExpiresAt time.Time + Used int64 } diff --git a/db/queries.sql b/db/queries.sql index 08d92f3..831f452 100644 --- a/db/queries.sql +++ b/db/queries.sql @@ -72,9 +72,15 @@ SELECT COUNT(*) FROM rsvps WHERE event_id = ?; -- name: GetUserByPhone :one SELECT * FROM users WHERE phone = ?; --- name: CreateUser :one +-- name: GetUserByEmail :one +SELECT * FROM users WHERE email = ?; + +-- name: CreateUserByPhone :one INSERT INTO users (phone, name) VALUES (?, '') RETURNING *; +-- name: CreateUserByEmail :one +INSERT INTO users (email, name) VALUES (?, '') RETURNING *; + -- name: UpdateUserName :exec UPDATE users SET name = ? WHERE id = ?; @@ -91,11 +97,11 @@ DELETE FROM sessions WHERE token = ?; DELETE FROM sessions WHERE expires_at <= CURRENT_TIMESTAMP; -- name: CreateVerificationCode :exec -INSERT INTO verification_codes (phone, code, expires_at) VALUES (?, ?, ?); +INSERT INTO verification_codes (identifier, code, expires_at) VALUES (?, ?, ?); -- name: GetVerificationCode :one SELECT * FROM verification_codes -WHERE phone = ? AND code = ? AND used = 0 AND expires_at > CURRENT_TIMESTAMP +WHERE identifier = ? AND code = ? AND used = 0 AND expires_at > CURRENT_TIMESTAMP ORDER BY id DESC LIMIT 1; -- name: MarkVerificationCodeUsed :exec diff --git a/db/queries.sql.go b/db/queries.sql.go index 57c616a..48fb5e7 100644 --- a/db/queries.sql.go +++ b/db/queries.sql.go @@ -174,16 +174,34 @@ func (q *Queries) CreateSlot(ctx context.Context, arg CreateSlotParams) (Slot, e return i, err } -const createUser = `-- name: CreateUser :one -INSERT INTO users (phone, name) VALUES (?, '') RETURNING id, phone, name, created_at +const createUserByEmail = `-- name: CreateUserByEmail :one +INSERT INTO users (email, name) VALUES (?, '') RETURNING id, phone, email, name, created_at ` -func (q *Queries) CreateUser(ctx context.Context, phone string) (User, error) { - row := q.db.QueryRowContext(ctx, createUser, phone) +func (q *Queries) CreateUserByEmail(ctx context.Context, email sql.NullString) (User, error) { + row := q.db.QueryRowContext(ctx, createUserByEmail, email) var i User err := row.Scan( &i.ID, &i.Phone, + &i.Email, + &i.Name, + &i.CreatedAt, + ) + return i, err +} + +const createUserByPhone = `-- name: CreateUserByPhone :one +INSERT INTO users (phone, name) VALUES (?, '') RETURNING id, phone, email, name, created_at +` + +func (q *Queries) CreateUserByPhone(ctx context.Context, phone sql.NullString) (User, error) { + row := q.db.QueryRowContext(ctx, createUserByPhone, phone) + var i User + err := row.Scan( + &i.ID, + &i.Phone, + &i.Email, &i.Name, &i.CreatedAt, ) @@ -191,17 +209,17 @@ func (q *Queries) CreateUser(ctx context.Context, phone string) (User, error) { } const createVerificationCode = `-- name: CreateVerificationCode :exec -INSERT INTO verification_codes (phone, code, expires_at) VALUES (?, ?, ?) +INSERT INTO verification_codes (identifier, code, expires_at) VALUES (?, ?, ?) ` type CreateVerificationCodeParams struct { - Phone string - Code string - ExpiresAt time.Time + Identifier string + Code string + ExpiresAt time.Time } func (q *Queries) CreateVerificationCode(ctx context.Context, arg CreateVerificationCodeParams) error { - _, err := q.db.ExecContext(ctx, createVerificationCode, arg.Phone, arg.Code, arg.ExpiresAt) + _, err := q.db.ExecContext(ctx, createVerificationCode, arg.Identifier, arg.Code, arg.ExpiresAt) return err } @@ -332,16 +350,34 @@ func (q *Queries) GetSlot(ctx context.Context, id int64) (Slot, error) { return i, err } -const getUserByPhone = `-- name: GetUserByPhone :one -SELECT id, phone, name, created_at FROM users WHERE phone = ? +const getUserByEmail = `-- name: GetUserByEmail :one +SELECT id, phone, email, name, created_at FROM users WHERE email = ? ` -func (q *Queries) GetUserByPhone(ctx context.Context, phone string) (User, error) { +func (q *Queries) GetUserByEmail(ctx context.Context, email sql.NullString) (User, error) { + row := q.db.QueryRowContext(ctx, getUserByEmail, email) + var i User + err := row.Scan( + &i.ID, + &i.Phone, + &i.Email, + &i.Name, + &i.CreatedAt, + ) + return i, err +} + +const getUserByPhone = `-- name: GetUserByPhone :one +SELECT id, phone, email, name, created_at FROM users WHERE phone = ? +` + +func (q *Queries) GetUserByPhone(ctx context.Context, phone sql.NullString) (User, error) { row := q.db.QueryRowContext(ctx, getUserByPhone, phone) var i User err := row.Scan( &i.ID, &i.Phone, + &i.Email, &i.Name, &i.CreatedAt, ) @@ -349,22 +385,22 @@ func (q *Queries) GetUserByPhone(ctx context.Context, phone string) (User, error } const getVerificationCode = `-- name: GetVerificationCode :one -SELECT id, phone, code, expires_at, used FROM verification_codes -WHERE phone = ? AND code = ? AND used = 0 AND expires_at > CURRENT_TIMESTAMP +SELECT id, identifier, code, expires_at, used FROM verification_codes +WHERE identifier = ? AND code = ? AND used = 0 AND expires_at > CURRENT_TIMESTAMP ORDER BY id DESC LIMIT 1 ` type GetVerificationCodeParams struct { - Phone string - Code string + Identifier string + Code string } func (q *Queries) GetVerificationCode(ctx context.Context, arg GetVerificationCodeParams) (VerificationCode, error) { - row := q.db.QueryRowContext(ctx, getVerificationCode, arg.Phone, arg.Code) + row := q.db.QueryRowContext(ctx, getVerificationCode, arg.Identifier, arg.Code) var i VerificationCode err := row.Scan( &i.ID, - &i.Phone, + &i.Identifier, &i.Code, &i.ExpiresAt, &i.Used, diff --git a/migrate.go b/migrate.go index 74d1fb7..547dc58 100644 --- a/migrate.go +++ b/migrate.go @@ -10,8 +10,11 @@ func runMigrations(database *sql.DB) { migrations := []string{ `ALTER TABLE events ADD COLUMN description TEXT DEFAULT ''`, `ALTER TABLE events ADD COLUMN user_id INTEGER REFERENCES users(id)`, - `ALTER TABLE users RENAME COLUMN email TO phone`, - `ALTER TABLE verification_codes RENAME COLUMN email TO phone`, + // Users may have email, phone, or both. Add whichever column is missing. + `ALTER TABLE users ADD COLUMN phone TEXT UNIQUE`, + `ALTER TABLE users ADD COLUMN email TEXT UNIQUE`, + // Verification codes use a generic identifier column. + `ALTER TABLE verification_codes ADD COLUMN identifier TEXT NOT NULL DEFAULT ''`, } for _, m := range migrations { _, err := database.Exec(m) diff --git a/schema.sql b/schema.sql index 4067858..6f37680 100644 --- a/schema.sql +++ b/schema.sql @@ -38,7 +38,8 @@ CREATE TABLE IF NOT EXISTS rsvps ( CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, - phone TEXT NOT NULL UNIQUE, + phone TEXT UNIQUE, + email TEXT UNIQUE, name TEXT NOT NULL DEFAULT '', created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ); @@ -51,7 +52,7 @@ CREATE TABLE IF NOT EXISTS sessions ( CREATE TABLE IF NOT EXISTS verification_codes ( id INTEGER PRIMARY KEY AUTOINCREMENT, - phone TEXT NOT NULL, + identifier TEXT NOT NULL, code TEXT NOT NULL, expires_at DATETIME NOT NULL, used INTEGER NOT NULL DEFAULT 0 @@ -61,4 +62,4 @@ CREATE INDEX IF NOT EXISTS idx_slots_event ON slots(event_id); CREATE INDEX IF NOT EXISTS idx_claims_slot ON claims(slot_id); CREATE INDEX IF NOT EXISTS idx_rsvps_event ON rsvps(event_id); CREATE INDEX IF NOT EXISTS idx_sessions_user ON sessions(user_id); -CREATE INDEX IF NOT EXISTS idx_verification_codes_phone ON verification_codes(phone); +CREATE INDEX IF NOT EXISTS idx_verification_codes_identifier ON verification_codes(identifier); diff --git a/templates/login.html b/templates/login.html index 615a4a8..9023c26 100644 --- a/templates/login.html +++ b/templates/login.html @@ -7,24 +7,25 @@
Account

Log in

- Enter your phone number to receive a login code. + Enter your email or phone number to receive a login code.

- {{if eq .Step "phone"}} + {{if eq .Step "identify"}}
- - + +
{{else}}
- + +

- Code sent to {{.Phone}} + Code sent to {{.Identifier}}

{{if .Error}}

{{.Error}}