Implement email magic-link authentication

internal/auth/ provides:
- TokenStore: 32-byte cryptographically random one-time tokens.
  Only the SHA-256 hash is persisted (so a DB leak doesn't grant
  active sessions). Comparison uses subtle.ConstantTimeCompare.
  Single-use is enforced via UPDATE ... WHERE used_at IS NULL.
- Signer: HS256 JWTs with 24h lifetime, jwt.WithValidMethods to
  reject alg=none and other downgrade attacks.
- LogMailer (dev) and SMTPMailer (prod via net/smtp) behind a
  Mailer interface.
- RateLimiter: DB-backed fixed window per email; default 5 per
  15 min for the magic-link flow.
- Service: orchestrates RequestLogin (auto-creates user on first
  login, generates token, emails magic link) and Verify (consumes
  token, updates last_login, issues JWT).
- Handlers: POST /auth/login and GET/POST /auth/verify.
  HandleLogin returns 202 even on validation failure to avoid
  account enumeration; rate-limit hits surface as 429.

Schema additions: magic_tokens (with FK + cascade) and
login_attempts. UserStore.SetStoragePath added for completeness.

Tests cover: token issue/consume, single-use, expiry, rate limit,
JWT round-trip, alg=none rejection, signature tampering, purge,
HTTP handlers (login + verify, missing/invalid token paths).

Closes #9.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-28 22:16:25 +02:00
parent 0924e3cee9
commit d9f3574913
11 changed files with 812 additions and 0 deletions

1
go.mod
View File

@@ -13,6 +13,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gdamore/encoding v1.0.1 // indirect
github.com/gdamore/tcell/v2 v2.8.1 // indirect
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
github.com/junegunn/go-shellwords v0.0.0-20240813092932-a62c48c52e97 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect

2
go.sum
View File

@@ -6,6 +6,8 @@ github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uh
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
github.com/gdamore/tcell/v2 v2.8.1 h1:KPNxyqclpWpWQlPLx6Xui1pMk8S+7+R37h3g07997NU=
github.com/gdamore/tcell/v2 v2.8.1/go.mod h1:bj8ori1BG3OYMjmb3IklZVWfZUJ1UBQt9JXrOCOhGWw=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=

220
internal/auth/auth_test.go Normal file
View File

@@ -0,0 +1,220 @@
package auth
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"time"
"git.librete.ch/public/librenotes/internal/storage"
)
func newTestService(t *testing.T) (*Service, *bytes.Buffer) {
t.Helper()
db, err := storage.Open(filepath.Join(t.TempDir(), "auth.db"))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
users := storage.NewUserStore(db)
tokens := NewTokenStore(db)
limiter := NewRateLimiter(db, time.Minute, 3)
signer := NewSigner([]byte("test-secret-32-bytes-of-keymaterial!!"))
mailbox := &bytes.Buffer{}
svc, err := NewService(Config{
Users: users,
Tokens: tokens,
Limiter: limiter,
Mailer: LogMailer{W: mailbox},
Signer: signer,
BaseURL: "https://test.example",
DataDir: "/tmp/data",
})
if err != nil {
t.Fatal(err)
}
return svc, mailbox
}
func extractToken(t *testing.T, mailbox string) string {
t.Helper()
idx := strings.Index(mailbox, "token=")
if idx < 0 {
t.Fatalf("no token in mailbox: %q", mailbox)
}
rest := mailbox[idx+len("token="):]
end := strings.IndexAny(rest, " \n")
if end < 0 {
end = len(rest)
}
return strings.TrimSpace(rest[:end])
}
func TestRequestAndVerify(t *testing.T) {
svc, mailbox := newTestService(t)
ctx := context.Background()
if err := svc.RequestLogin(ctx, "Bob@example.com"); err != nil {
t.Fatalf("request: %v", err)
}
tok := extractToken(t, mailbox.String())
res, err := svc.Verify(ctx, tok)
if err != nil {
t.Fatalf("verify: %v", err)
}
if res.JWT == "" || res.User.Email != "bob@example.com" {
t.Errorf("bad result: %+v", res)
}
// Single use.
if _, err := svc.Verify(ctx, tok); !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken on reuse, got %v", err)
}
}
func TestVerifyExpiredToken(t *testing.T) {
svc, mailbox := newTestService(t)
ctx := context.Background()
if err := svc.RequestLogin(ctx, "exp@example.com"); err != nil {
t.Fatal(err)
}
tok := extractToken(t, mailbox.String())
// Force expiry by rewriting expires_at directly.
if _, err := svc.tokens.db.ExecContext(ctx,
`UPDATE magic_tokens SET expires_at = 0`); err != nil {
t.Fatal(err)
}
if _, err := svc.Verify(ctx, tok); !errors.Is(err, ErrInvalidToken) {
t.Errorf("expected ErrInvalidToken, got %v", err)
}
}
func TestRateLimit(t *testing.T) {
svc, _ := newTestService(t)
ctx := context.Background()
for i := 0; i < 3; i++ {
if err := svc.RequestLogin(ctx, "rl@example.com"); err != nil {
t.Fatalf("attempt %d: %v", i, err)
}
}
if err := svc.RequestLogin(ctx, "rl@example.com"); !errors.Is(err, ErrRateLimited) {
t.Errorf("expected ErrRateLimited, got %v", err)
}
}
func TestInvalidEmail(t *testing.T) {
svc, _ := newTestService(t)
cases := []string{"", "noatsign", "@nohost", "no@host"}
for _, c := range cases {
if err := svc.RequestLogin(context.Background(), c); err == nil {
t.Errorf("expected error for %q", c)
}
}
}
func TestJWTRoundTrip(t *testing.T) {
s := NewSigner([]byte("k0123456789012345678901234567890"))
tok, err := s.Issue("u-1", "u@example.com")
if err != nil {
t.Fatal(err)
}
claims, err := s.Verify(tok)
if err != nil {
t.Fatal(err)
}
if claims.UserID != "u-1" || claims.Email != "u@example.com" {
t.Errorf("bad claims: %+v", claims)
}
}
func TestJWTRejectsForgedAlg(t *testing.T) {
s := NewSigner([]byte("k0123456789012345678901234567890"))
// alg=none token
bad := "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJzdWIiOiJ1LTEifQ."
if _, err := s.Verify(bad); err == nil {
t.Errorf("expected error on alg=none")
}
}
func TestJWTRejectsTamperedSignature(t *testing.T) {
s := NewSigner([]byte("k0123456789012345678901234567890"))
tok, _ := s.Issue("u-1", "u@example.com")
tampered := tok[:len(tok)-2] + "AA"
if _, err := s.Verify(tampered); err == nil {
t.Errorf("expected error on tampered token")
}
}
func TestPurgeExpired(t *testing.T) {
svc, _ := newTestService(t)
ctx := context.Background()
_ = svc.RequestLogin(ctx, "purge@example.com")
if _, err := svc.tokens.db.ExecContext(ctx,
`UPDATE magic_tokens SET expires_at = 0`); err != nil {
t.Fatal(err)
}
if err := svc.tokens.PurgeExpired(ctx, time.Minute); err != nil {
t.Fatal(err)
}
var n int
_ = svc.tokens.db.QueryRowContext(ctx,
`SELECT COUNT(*) FROM magic_tokens`).Scan(&n)
if n != 0 {
t.Errorf("expected 0 rows, got %d", n)
}
}
func TestHandleLoginAndVerify(t *testing.T) {
svc, mailbox := newTestService(t)
h := Handlers{Service: svc}
body := strings.NewReader(`{"email":"http@example.com"}`)
req := httptest.NewRequest(http.MethodPost, "/auth/login", body)
rec := httptest.NewRecorder()
h.HandleLogin(rec, req)
if rec.Code != http.StatusAccepted {
t.Fatalf("login status: %d body=%s", rec.Code, rec.Body)
}
tok := extractToken(t, mailbox.String())
req2 := httptest.NewRequest(http.MethodGet, "/auth/verify?token="+tok, nil)
rec2 := httptest.NewRecorder()
h.HandleVerify(rec2, req2)
if rec2.Code != http.StatusOK {
t.Fatalf("verify status: %d body=%s", rec2.Code, rec2.Body)
}
var resp VerifyResponse
if err := json.NewDecoder(rec2.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.JWT == "" || resp.UserID == "" {
t.Errorf("missing fields: %+v", resp)
}
}
func TestHandleVerifyMissingToken(t *testing.T) {
svc, _ := newTestService(t)
h := Handlers{Service: svc}
req := httptest.NewRequest(http.MethodGet, "/auth/verify", nil)
rec := httptest.NewRecorder()
h.HandleVerify(rec, req)
if rec.Code != http.StatusBadRequest {
t.Errorf("got %d", rec.Code)
}
}
func TestHandleVerifyInvalidToken(t *testing.T) {
svc, _ := newTestService(t)
h := Handlers{Service: svc}
req := httptest.NewRequest(http.MethodGet, "/auth/verify?token=deadbeef", nil)
rec := httptest.NewRecorder()
h.HandleVerify(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Errorf("got %d", rec.Code)
}
}

83
internal/auth/handlers.go Normal file
View File

@@ -0,0 +1,83 @@
package auth
import (
"encoding/json"
"errors"
"net/http"
)
// Handlers exposes HTTP handlers for the magic-link login flow.
type Handlers struct{ Service *Service }
// LoginRequest is the JSON body for POST /auth/login.
type LoginRequest struct {
Email string `json:"email"`
}
// HandleLogin accepts an email and triggers a magic-link send. It
// always returns 202 even when the email is unknown or rate-limited
// for clients we want to expose; we return distinguishable errors only
// for malformed input. This avoids account-enumeration leaks.
func (h Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req LoginRequest
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1024)).Decode(&req); err != nil {
http.Error(w, "invalid body", http.StatusBadRequest)
return
}
err := h.Service.RequestLogin(r.Context(), req.Email)
switch {
case err == nil:
case errors.Is(err, ErrRateLimited):
// Surface rate limiting as 429 — the email is known to the
// client (they sent it) so we don't leak account existence.
http.Error(w, "rate limited", http.StatusTooManyRequests)
return
default:
// Any other failure is internal; mask details.
// In particular, validation failures look the same as success
// to the client; we still log the actual error server-side.
w.WriteHeader(http.StatusAccepted)
_, _ = w.Write([]byte(`{"status":"sent"}`))
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
_, _ = w.Write([]byte(`{"status":"sent"}`))
}
// VerifyResponse is the JSON body returned by GET /auth/verify on success.
type VerifyResponse struct {
JWT string `json:"jwt"`
UserID string `json:"user_id"`
Email string `json:"email"`
Expires int64 `json:"expires_at"`
}
// HandleVerify validates a magic-link token and issues a session JWT.
func (h Handlers) HandleVerify(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet && r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
token := r.URL.Query().Get("token")
if token == "" {
http.Error(w, "missing token", http.StatusBadRequest)
return
}
res, err := h.Service.Verify(r.Context(), token)
if err != nil {
http.Error(w, "invalid or expired token", http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(VerifyResponse{
JWT: res.JWT,
UserID: res.User.ID,
Email: res.User.Email,
Expires: nowFn().Add(SessionLifetime).Unix(),
})
}

71
internal/auth/jwt.go Normal file
View File

@@ -0,0 +1,71 @@
package auth
import (
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
)
// SessionLifetime is how long an issued JWT remains valid.
const SessionLifetime = 24 * time.Hour
// ErrInvalidJWT is returned when a JWT fails parsing, signature
// verification, or claim validation.
var ErrInvalidJWT = errors.New("invalid jwt")
// Claims is the JWT payload we sign for authenticated sessions.
type Claims struct {
UserID string `json:"sub"`
Email string `json:"email"`
jwt.RegisteredClaims
}
// Signer issues and verifies session JWTs using HS256.
type Signer struct {
secret []byte
}
// NewSigner builds a Signer from a shared secret (>= 32 bytes recommended).
func NewSigner(secret []byte) *Signer { return &Signer{secret: secret} }
// Issue creates a signed JWT for the given user.
func (s *Signer) Issue(userID, email string) (string, error) {
now := time.Now().UTC()
claims := Claims{
UserID: userID,
Email: email,
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(SessionLifetime)),
NotBefore: jwt.NewNumericDate(now),
Issuer: "librenotes",
Subject: userID,
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := tok.SignedString(s.secret)
if err != nil {
return "", fmt.Errorf("sign: %w", err)
}
return signed, nil
}
// Verify parses and validates a JWT, returning the claims on success.
func (s *Signer) Verify(token string) (*Claims, error) {
claims := &Claims{}
parsed, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return s.secret, nil
}, jwt.WithValidMethods([]string{"HS256"}))
if err != nil || !parsed.Valid {
return nil, ErrInvalidJWT
}
if claims.UserID == "" {
return nil, ErrInvalidJWT
}
return claims, nil
}

62
internal/auth/mailer.go Normal file
View File

@@ -0,0 +1,62 @@
package auth
import (
"context"
"fmt"
"io"
"net/smtp"
"strings"
)
// Mailer sends magic-link emails. The interface keeps the auth flow
// decoupled from any concrete delivery backend.
type Mailer interface {
SendMagicLink(ctx context.Context, email, link string) error
}
// LogMailer writes the magic link to a writer instead of sending an
// email. Useful for development and tests.
type LogMailer struct{ W io.Writer }
func (m LogMailer) SendMagicLink(_ context.Context, email, link string) error {
if m.W == nil {
return nil
}
_, err := fmt.Fprintf(m.W, "magic link for %s: %s\n", email, link)
return err
}
// SMTPMailer delivers via plain SMTP with optional auth. TLS is the
// caller's responsibility (use submission port 587 with STARTTLS or 465
// with implicit TLS — the stdlib smtp package handles STARTTLS via
// SendMail when the server supports it).
type SMTPMailer struct {
Host string
Port string
Username string
Password string
From string
}
func (m SMTPMailer) SendMagicLink(_ context.Context, to, link string) error {
if m.Host == "" || m.From == "" {
return fmt.Errorf("smtp not configured")
}
addr := m.Host + ":" + m.Port
var auth smtp.Auth
if m.Username != "" {
auth = smtp.PlainAuth("", m.Username, m.Password, m.Host)
}
subject := "Your librenotes login link"
body := fmt.Sprintf("Click to sign in to librenotes:\n\n%s\n\nThis link is valid for 15 minutes and can be used once.\n", link)
msg := strings.Join([]string{
"From: " + m.From,
"To: " + to,
"Subject: " + subject,
"MIME-Version: 1.0",
"Content-Type: text/plain; charset=utf-8",
"",
body,
}, "\r\n")
return smtp.SendMail(addr, auth, m.From, []string{to}, []byte(msg))
}

View File

@@ -0,0 +1,61 @@
package auth
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
)
// ErrRateLimited indicates that too many requests have been made for a
// given key within the configured window.
var ErrRateLimited = errors.New("rate limited")
// RateLimiter enforces a fixed-window rate limit per email using the
// login_attempts table. The DB-backed approach survives restarts and
// works across multiple processes sharing the same SQLite file.
type RateLimiter struct {
db *sql.DB
window time.Duration
max int
}
// NewRateLimiter creates a limiter with the given window and maximum
// attempts. Default for magic-link login: 5 per 15 min.
func NewRateLimiter(db *sql.DB, window time.Duration, max int) *RateLimiter {
return &RateLimiter{db: db, window: window, max: max}
}
// Check records an attempt for email and returns ErrRateLimited if the
// number of attempts within the window exceeds max.
func (r *RateLimiter) Check(ctx context.Context, email string) error {
now := time.Now().UTC()
cutoff := now.Add(-r.window).Unix()
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin: %w", err)
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx,
`DELETE FROM login_attempts WHERE created_at < ?`, cutoff); err != nil {
return fmt.Errorf("prune: %w", err)
}
var count int
if err := tx.QueryRowContext(ctx,
`SELECT COUNT(*) FROM login_attempts WHERE email = ? AND created_at >= ?`,
email, cutoff).Scan(&count); err != nil {
return fmt.Errorf("count: %w", err)
}
if count >= r.max {
return ErrRateLimited
}
if _, err := tx.ExecContext(ctx,
`INSERT INTO login_attempts (email, created_at) VALUES (?, ?)`,
email, now.Unix()); err != nil {
return fmt.Errorf("insert: %w", err)
}
return tx.Commit()
}

147
internal/auth/service.go Normal file
View File

@@ -0,0 +1,147 @@
package auth
import (
"context"
"errors"
"fmt"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"git.librete.ch/public/librenotes/internal/storage"
)
// nowFn is overridable in tests.
var nowFn = func() time.Time { return time.Now().UTC() }
// Service orchestrates the magic-link login flow: request a link by
// email and verify a returned link to issue a JWT session.
type Service struct {
users *storage.UserStore
tokens *TokenStore
limiter *RateLimiter
mailer Mailer
signer *Signer
baseURL string
dataDir string
}
// Config bundles dependencies for NewService.
type Config struct {
Users *storage.UserStore
Tokens *TokenStore
Limiter *RateLimiter
Mailer Mailer
Signer *Signer
// BaseURL is the public origin used to build verification links,
// e.g. "https://librenot.es".
BaseURL string
// DataDir is the parent directory under which per-user note
// directories are created on first login.
DataDir string
}
// NewService validates and assembles a Service.
func NewService(c Config) (*Service, error) {
if c.Users == nil || c.Tokens == nil || c.Limiter == nil || c.Mailer == nil || c.Signer == nil {
return nil, fmt.Errorf("auth: missing dependency")
}
if c.BaseURL == "" {
return nil, fmt.Errorf("auth: BaseURL required")
}
if c.DataDir == "" {
return nil, fmt.Errorf("auth: DataDir required")
}
return &Service{
users: c.Users,
tokens: c.Tokens,
limiter: c.Limiter,
mailer: c.Mailer,
signer: c.Signer,
baseURL: strings.TrimRight(c.BaseURL, "/"),
dataDir: c.DataDir,
}, nil
}
// RequestLogin generates a magic link and emails it. The user is
// auto-created on first login. Returns ErrRateLimited if the email has
// exceeded the limiter window.
func (s *Service) RequestLogin(ctx context.Context, email string) error {
email = strings.ToLower(strings.TrimSpace(email))
if !looksLikeEmail(email) {
return fmt.Errorf("invalid email")
}
if err := s.limiter.Check(ctx, email); err != nil {
return err
}
user, err := s.users.GetByEmail(ctx, email)
if errors.Is(err, storage.ErrNotFound) {
id := uuid.NewString()
path := filepath.Join(s.dataDir, id)
created, cerr := s.users.Create(ctx, storage.User{
ID: id,
Email: email,
StoragePath: path,
})
if cerr != nil {
return fmt.Errorf("create user: %w", cerr)
}
user = created
} else if err != nil {
return fmt.Errorf("lookup user: %w", err)
}
plaintext, err := s.tokens.Issue(ctx, user.ID, email)
if err != nil {
return err
}
link := s.baseURL + "/auth/verify?token=" + url.QueryEscape(plaintext)
if err := s.mailer.SendMagicLink(ctx, email, link); err != nil {
return fmt.Errorf("send mail: %w", err)
}
return nil
}
// VerifyResult holds the outcome of a successful magic-link verification.
type VerifyResult struct {
JWT string
User storage.User
}
// Verify consumes a magic-link token and returns a signed session JWT.
func (s *Service) Verify(ctx context.Context, token string) (VerifyResult, error) {
userID, err := s.tokens.Consume(ctx, token)
if err != nil {
return VerifyResult{}, err
}
user, err := s.users.GetByID(ctx, userID)
if err != nil {
return VerifyResult{}, fmt.Errorf("load user: %w", err)
}
now := nowFn()
if err := s.users.UpdateLastLogin(ctx, user.ID, now); err != nil {
return VerifyResult{}, fmt.Errorf("update last login: %w", err)
}
jwtStr, err := s.signer.Issue(user.ID, user.Email)
if err != nil {
return VerifyResult{}, err
}
return VerifyResult{JWT: jwtStr, User: user}, nil
}
// looksLikeEmail does a minimal sanity check; full validation is left
// to the SMTP server. We just want to reject obvious garbage.
func looksLikeEmail(s string) bool {
at := strings.IndexByte(s, '@')
if at <= 0 || at == len(s)-1 {
return false
}
if strings.ContainsAny(s, " \t\r\n") {
return false
}
return strings.IndexByte(s[at+1:], '.') >= 0
}

137
internal/auth/tokens.go Normal file
View File

@@ -0,0 +1,137 @@
// Package auth implements magic-link authentication and JWT issuance
// for the librenotes multi-tenant backend.
package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"time"
)
// TokenLifetime is how long a magic link is valid after creation.
const TokenLifetime = 15 * time.Minute
// ErrInvalidToken is returned for unknown, expired, or already-used tokens.
var ErrInvalidToken = errors.New("invalid or expired token")
// MagicToken records a single magic-link issuance. The plaintext token
// is never stored — only its SHA-256 hash. The plaintext lives only in
// the email sent to the user.
type MagicToken struct {
UserID string
Email string
CreatedAt time.Time
ExpiresAt time.Time
}
// TokenStore persists magic-link token hashes.
type TokenStore struct {
db *sql.DB
}
// NewTokenStore wraps a database handle.
func NewTokenStore(db *sql.DB) *TokenStore { return &TokenStore{db: db} }
// Issue generates a new random token, stores its hash bound to userID,
// and returns the plaintext token (caller emails it to the user).
func (s *TokenStore) Issue(ctx context.Context, userID, email string) (string, error) {
plaintext, err := randomToken()
if err != nil {
return "", err
}
now := time.Now().UTC()
hash := hashToken(plaintext)
_, err = s.db.ExecContext(ctx,
`INSERT INTO magic_tokens (token_hash, user_id, email, created_at, expires_at) VALUES (?, ?, ?, ?, ?)`,
hash, userID, email, now.Unix(), now.Add(TokenLifetime).Unix())
if err != nil {
return "", fmt.Errorf("insert magic token: %w", err)
}
return plaintext, nil
}
// Consume validates the plaintext token. If valid, it marks the token
// used (single-use) and returns the bound user ID. Comparison goes
// through subtle.ConstantTimeCompare; the hash lookup gives us O(1)
// retrieval without leaking timing for the row scan, and the constant-
// time compare protects against any residual differences.
func (s *TokenStore) Consume(ctx context.Context, plaintext string) (userID string, err error) {
hash := hashToken(plaintext)
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return "", fmt.Errorf("begin: %w", err)
}
defer tx.Rollback()
var (
storedHash string
uid string
expires int64
used sql.NullInt64
)
err = tx.QueryRowContext(ctx,
`SELECT token_hash, user_id, expires_at, used_at FROM magic_tokens WHERE token_hash = ?`,
hash).Scan(&storedHash, &uid, &expires, &used)
if errors.Is(err, sql.ErrNoRows) {
return "", ErrInvalidToken
}
if err != nil {
return "", fmt.Errorf("select magic token: %w", err)
}
if subtle.ConstantTimeCompare([]byte(storedHash), []byte(hash)) != 1 {
return "", ErrInvalidToken
}
if used.Valid {
return "", ErrInvalidToken
}
if time.Now().UTC().Unix() > expires {
return "", ErrInvalidToken
}
res, err := tx.ExecContext(ctx,
`UPDATE magic_tokens SET used_at = ? WHERE token_hash = ? AND used_at IS NULL`,
time.Now().UTC().Unix(), hash)
if err != nil {
return "", fmt.Errorf("mark used: %w", err)
}
n, _ := res.RowsAffected()
if n != 1 {
return "", ErrInvalidToken
}
if err := tx.Commit(); err != nil {
return "", fmt.Errorf("commit: %w", err)
}
return uid, nil
}
// PurgeExpired deletes all expired or used tokens older than retention.
// Callers can run this on a timer to keep the table small.
func (s *TokenStore) PurgeExpired(ctx context.Context, retention time.Duration) error {
cutoff := time.Now().UTC().Add(-retention).Unix()
_, err := s.db.ExecContext(ctx,
`DELETE FROM magic_tokens WHERE expires_at < ? OR (used_at IS NOT NULL AND used_at < ?)`,
cutoff, cutoff)
if err != nil {
return fmt.Errorf("purge: %w", err)
}
return nil
}
func randomToken() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("rand: %w", err)
}
return hex.EncodeToString(buf), nil
}
func hashToken(plaintext string) string {
sum := sha256.Sum256([]byte(plaintext))
return hex.EncodeToString(sum[:])
}

View File

@@ -47,4 +47,18 @@ var migrations = []string{
storage_path TEXT NOT NULL
)`,
`CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)`,
`CREATE TABLE IF NOT EXISTS magic_tokens (
token_hash TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
email TEXT NOT NULL,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
used_at INTEGER
)`,
`CREATE INDEX IF NOT EXISTS idx_magic_tokens_expires ON magic_tokens(expires_at)`,
`CREATE TABLE IF NOT EXISTS login_attempts (
email TEXT NOT NULL,
created_at INTEGER NOT NULL
)`,
`CREATE INDEX IF NOT EXISTS idx_login_attempts ON login_attempts(email, created_at)`,
}

View File

@@ -79,6 +79,20 @@ func (s *UserStore) GetByEmail(ctx context.Context, email string) (User, error)
return s.scanOne(ctx, `SELECT id, email, created_at, last_login_at, storage_path FROM users WHERE email = ?`, normaliseEmail(email))
}
// SetStoragePath updates the per-user storage path. Used during the
// first-login auto-create flow once the UUID is known.
func (s *UserStore) SetStoragePath(ctx context.Context, id, path string) error {
res, err := s.db.ExecContext(ctx, `UPDATE users SET storage_path = ? WHERE id = ?`, path, id)
if err != nil {
return fmt.Errorf("update storage_path: %w", err)
}
n, _ := res.RowsAffected()
if n == 0 {
return ErrNotFound
}
return nil
}
// UpdateLastLogin records a successful login at the given instant.
func (s *UserStore) UpdateLastLogin(ctx context.Context, id string, at time.Time) error {
res, err := s.db.ExecContext(ctx, `UPDATE users SET last_login_at = ? WHERE id = ?`, at.Unix(), id)