Implement email magic-link authentication
internal/auth/ provides: - TokenStore: 32-byte cryptographically random one-time tokens. Only the SHA-256 hash is persisted (so a DB leak doesn't grant active sessions). Comparison uses subtle.ConstantTimeCompare. Single-use is enforced via UPDATE ... WHERE used_at IS NULL. - Signer: HS256 JWTs with 24h lifetime, jwt.WithValidMethods to reject alg=none and other downgrade attacks. - LogMailer (dev) and SMTPMailer (prod via net/smtp) behind a Mailer interface. - RateLimiter: DB-backed fixed window per email; default 5 per 15 min for the magic-link flow. - Service: orchestrates RequestLogin (auto-creates user on first login, generates token, emails magic link) and Verify (consumes token, updates last_login, issues JWT). - Handlers: POST /auth/login and GET/POST /auth/verify. HandleLogin returns 202 even on validation failure to avoid account enumeration; rate-limit hits surface as 429. Schema additions: magic_tokens (with FK + cascade) and login_attempts. UserStore.SetStoragePath added for completeness. Tests cover: token issue/consume, single-use, expiry, rate limit, JWT round-trip, alg=none rejection, signature tampering, purge, HTTP handlers (login + verify, missing/invalid token paths). Closes #9. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
1
go.mod
1
go.mod
@@ -13,6 +13,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/gdamore/tcell/v2 v2.8.1 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||
github.com/junegunn/go-shellwords v0.0.0-20240813092932-a62c48c52e97 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -6,6 +6,8 @@ github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uh
|
||||
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
||||
github.com/gdamore/tcell/v2 v2.8.1 h1:KPNxyqclpWpWQlPLx6Xui1pMk8S+7+R37h3g07997NU=
|
||||
github.com/gdamore/tcell/v2 v2.8.1/go.mod h1:bj8ori1BG3OYMjmb3IklZVWfZUJ1UBQt9JXrOCOhGWw=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
|
||||
220
internal/auth/auth_test.go
Normal file
220
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.librete.ch/public/librenotes/internal/storage"
|
||||
)
|
||||
|
||||
func newTestService(t *testing.T) (*Service, *bytes.Buffer) {
|
||||
t.Helper()
|
||||
db, err := storage.Open(filepath.Join(t.TempDir(), "auth.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
users := storage.NewUserStore(db)
|
||||
tokens := NewTokenStore(db)
|
||||
limiter := NewRateLimiter(db, time.Minute, 3)
|
||||
signer := NewSigner([]byte("test-secret-32-bytes-of-keymaterial!!"))
|
||||
mailbox := &bytes.Buffer{}
|
||||
svc, err := NewService(Config{
|
||||
Users: users,
|
||||
Tokens: tokens,
|
||||
Limiter: limiter,
|
||||
Mailer: LogMailer{W: mailbox},
|
||||
Signer: signer,
|
||||
BaseURL: "https://test.example",
|
||||
DataDir: "/tmp/data",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return svc, mailbox
|
||||
}
|
||||
|
||||
func extractToken(t *testing.T, mailbox string) string {
|
||||
t.Helper()
|
||||
idx := strings.Index(mailbox, "token=")
|
||||
if idx < 0 {
|
||||
t.Fatalf("no token in mailbox: %q", mailbox)
|
||||
}
|
||||
rest := mailbox[idx+len("token="):]
|
||||
end := strings.IndexAny(rest, " \n")
|
||||
if end < 0 {
|
||||
end = len(rest)
|
||||
}
|
||||
return strings.TrimSpace(rest[:end])
|
||||
}
|
||||
|
||||
func TestRequestAndVerify(t *testing.T) {
|
||||
svc, mailbox := newTestService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
if err := svc.RequestLogin(ctx, "Bob@example.com"); err != nil {
|
||||
t.Fatalf("request: %v", err)
|
||||
}
|
||||
tok := extractToken(t, mailbox.String())
|
||||
res, err := svc.Verify(ctx, tok)
|
||||
if err != nil {
|
||||
t.Fatalf("verify: %v", err)
|
||||
}
|
||||
if res.JWT == "" || res.User.Email != "bob@example.com" {
|
||||
t.Errorf("bad result: %+v", res)
|
||||
}
|
||||
// Single use.
|
||||
if _, err := svc.Verify(ctx, tok); !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken on reuse, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyExpiredToken(t *testing.T) {
|
||||
svc, mailbox := newTestService(t)
|
||||
ctx := context.Background()
|
||||
if err := svc.RequestLogin(ctx, "exp@example.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tok := extractToken(t, mailbox.String())
|
||||
// Force expiry by rewriting expires_at directly.
|
||||
if _, err := svc.tokens.db.ExecContext(ctx,
|
||||
`UPDATE magic_tokens SET expires_at = 0`); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := svc.Verify(ctx, tok); !errors.Is(err, ErrInvalidToken) {
|
||||
t.Errorf("expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
svc, _ := newTestService(t)
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := svc.RequestLogin(ctx, "rl@example.com"); err != nil {
|
||||
t.Fatalf("attempt %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
if err := svc.RequestLogin(ctx, "rl@example.com"); !errors.Is(err, ErrRateLimited) {
|
||||
t.Errorf("expected ErrRateLimited, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidEmail(t *testing.T) {
|
||||
svc, _ := newTestService(t)
|
||||
cases := []string{"", "noatsign", "@nohost", "no@host"}
|
||||
for _, c := range cases {
|
||||
if err := svc.RequestLogin(context.Background(), c); err == nil {
|
||||
t.Errorf("expected error for %q", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTRoundTrip(t *testing.T) {
|
||||
s := NewSigner([]byte("k0123456789012345678901234567890"))
|
||||
tok, err := s.Issue("u-1", "u@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
claims, err := s.Verify(tok)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if claims.UserID != "u-1" || claims.Email != "u@example.com" {
|
||||
t.Errorf("bad claims: %+v", claims)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTRejectsForgedAlg(t *testing.T) {
|
||||
s := NewSigner([]byte("k0123456789012345678901234567890"))
|
||||
// alg=none token
|
||||
bad := "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJzdWIiOiJ1LTEifQ."
|
||||
if _, err := s.Verify(bad); err == nil {
|
||||
t.Errorf("expected error on alg=none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTRejectsTamperedSignature(t *testing.T) {
|
||||
s := NewSigner([]byte("k0123456789012345678901234567890"))
|
||||
tok, _ := s.Issue("u-1", "u@example.com")
|
||||
tampered := tok[:len(tok)-2] + "AA"
|
||||
if _, err := s.Verify(tampered); err == nil {
|
||||
t.Errorf("expected error on tampered token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPurgeExpired(t *testing.T) {
|
||||
svc, _ := newTestService(t)
|
||||
ctx := context.Background()
|
||||
_ = svc.RequestLogin(ctx, "purge@example.com")
|
||||
if _, err := svc.tokens.db.ExecContext(ctx,
|
||||
`UPDATE magic_tokens SET expires_at = 0`); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := svc.tokens.PurgeExpired(ctx, time.Minute); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var n int
|
||||
_ = svc.tokens.db.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM magic_tokens`).Scan(&n)
|
||||
if n != 0 {
|
||||
t.Errorf("expected 0 rows, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleLoginAndVerify(t *testing.T) {
|
||||
svc, mailbox := newTestService(t)
|
||||
h := Handlers{Service: svc}
|
||||
|
||||
body := strings.NewReader(`{"email":"http@example.com"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/auth/login", body)
|
||||
rec := httptest.NewRecorder()
|
||||
h.HandleLogin(rec, req)
|
||||
if rec.Code != http.StatusAccepted {
|
||||
t.Fatalf("login status: %d body=%s", rec.Code, rec.Body)
|
||||
}
|
||||
|
||||
tok := extractToken(t, mailbox.String())
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/auth/verify?token="+tok, nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
h.HandleVerify(rec2, req2)
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("verify status: %d body=%s", rec2.Code, rec2.Body)
|
||||
}
|
||||
var resp VerifyResponse
|
||||
if err := json.NewDecoder(rec2.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.JWT == "" || resp.UserID == "" {
|
||||
t.Errorf("missing fields: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVerifyMissingToken(t *testing.T) {
|
||||
svc, _ := newTestService(t)
|
||||
h := Handlers{Service: svc}
|
||||
req := httptest.NewRequest(http.MethodGet, "/auth/verify", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.HandleVerify(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVerifyInvalidToken(t *testing.T) {
|
||||
svc, _ := newTestService(t)
|
||||
h := Handlers{Service: svc}
|
||||
req := httptest.NewRequest(http.MethodGet, "/auth/verify?token=deadbeef", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.HandleVerify(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Errorf("got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
83
internal/auth/handlers.go
Normal file
83
internal/auth/handlers.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Handlers exposes HTTP handlers for the magic-link login flow.
|
||||
type Handlers struct{ Service *Service }
|
||||
|
||||
// LoginRequest is the JSON body for POST /auth/login.
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// HandleLogin accepts an email and triggers a magic-link send. It
|
||||
// always returns 202 even when the email is unknown or rate-limited
|
||||
// for clients we want to expose; we return distinguishable errors only
|
||||
// for malformed input. This avoids account-enumeration leaks.
|
||||
func (h Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1024)).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
err := h.Service.RequestLogin(r.Context(), req.Email)
|
||||
switch {
|
||||
case err == nil:
|
||||
case errors.Is(err, ErrRateLimited):
|
||||
// Surface rate limiting as 429 — the email is known to the
|
||||
// client (they sent it) so we don't leak account existence.
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
default:
|
||||
// Any other failure is internal; mask details.
|
||||
// In particular, validation failures look the same as success
|
||||
// to the client; we still log the actual error server-side.
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
_, _ = w.Write([]byte(`{"status":"sent"}`))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
_, _ = w.Write([]byte(`{"status":"sent"}`))
|
||||
}
|
||||
|
||||
// VerifyResponse is the JSON body returned by GET /auth/verify on success.
|
||||
type VerifyResponse struct {
|
||||
JWT string `json:"jwt"`
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Expires int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
// HandleVerify validates a magic-link token and issues a session JWT.
|
||||
func (h Handlers) HandleVerify(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
token := r.URL.Query().Get("token")
|
||||
if token == "" {
|
||||
http.Error(w, "missing token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
res, err := h.Service.Verify(r.Context(), token)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid or expired token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(VerifyResponse{
|
||||
JWT: res.JWT,
|
||||
UserID: res.User.ID,
|
||||
Email: res.User.Email,
|
||||
Expires: nowFn().Add(SessionLifetime).Unix(),
|
||||
})
|
||||
}
|
||||
71
internal/auth/jwt.go
Normal file
71
internal/auth/jwt.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// SessionLifetime is how long an issued JWT remains valid.
|
||||
const SessionLifetime = 24 * time.Hour
|
||||
|
||||
// ErrInvalidJWT is returned when a JWT fails parsing, signature
|
||||
// verification, or claim validation.
|
||||
var ErrInvalidJWT = errors.New("invalid jwt")
|
||||
|
||||
// Claims is the JWT payload we sign for authenticated sessions.
|
||||
type Claims struct {
|
||||
UserID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// Signer issues and verifies session JWTs using HS256.
|
||||
type Signer struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
// NewSigner builds a Signer from a shared secret (>= 32 bytes recommended).
|
||||
func NewSigner(secret []byte) *Signer { return &Signer{secret: secret} }
|
||||
|
||||
// Issue creates a signed JWT for the given user.
|
||||
func (s *Signer) Issue(userID, email string) (string, error) {
|
||||
now := time.Now().UTC()
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(SessionLifetime)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: "librenotes",
|
||||
Subject: userID,
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := tok.SignedString(s.secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign: %w", err)
|
||||
}
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
// Verify parses and validates a JWT, returning the claims on success.
|
||||
func (s *Signer) Verify(token string) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
parsed, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return s.secret, nil
|
||||
}, jwt.WithValidMethods([]string{"HS256"}))
|
||||
if err != nil || !parsed.Valid {
|
||||
return nil, ErrInvalidJWT
|
||||
}
|
||||
if claims.UserID == "" {
|
||||
return nil, ErrInvalidJWT
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
62
internal/auth/mailer.go
Normal file
62
internal/auth/mailer.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Mailer sends magic-link emails. The interface keeps the auth flow
|
||||
// decoupled from any concrete delivery backend.
|
||||
type Mailer interface {
|
||||
SendMagicLink(ctx context.Context, email, link string) error
|
||||
}
|
||||
|
||||
// LogMailer writes the magic link to a writer instead of sending an
|
||||
// email. Useful for development and tests.
|
||||
type LogMailer struct{ W io.Writer }
|
||||
|
||||
func (m LogMailer) SendMagicLink(_ context.Context, email, link string) error {
|
||||
if m.W == nil {
|
||||
return nil
|
||||
}
|
||||
_, err := fmt.Fprintf(m.W, "magic link for %s: %s\n", email, link)
|
||||
return err
|
||||
}
|
||||
|
||||
// SMTPMailer delivers via plain SMTP with optional auth. TLS is the
|
||||
// caller's responsibility (use submission port 587 with STARTTLS or 465
|
||||
// with implicit TLS — the stdlib smtp package handles STARTTLS via
|
||||
// SendMail when the server supports it).
|
||||
type SMTPMailer struct {
|
||||
Host string
|
||||
Port string
|
||||
Username string
|
||||
Password string
|
||||
From string
|
||||
}
|
||||
|
||||
func (m SMTPMailer) SendMagicLink(_ context.Context, to, link string) error {
|
||||
if m.Host == "" || m.From == "" {
|
||||
return fmt.Errorf("smtp not configured")
|
||||
}
|
||||
addr := m.Host + ":" + m.Port
|
||||
var auth smtp.Auth
|
||||
if m.Username != "" {
|
||||
auth = smtp.PlainAuth("", m.Username, m.Password, m.Host)
|
||||
}
|
||||
subject := "Your librenotes login link"
|
||||
body := fmt.Sprintf("Click to sign in to librenotes:\n\n%s\n\nThis link is valid for 15 minutes and can be used once.\n", link)
|
||||
msg := strings.Join([]string{
|
||||
"From: " + m.From,
|
||||
"To: " + to,
|
||||
"Subject: " + subject,
|
||||
"MIME-Version: 1.0",
|
||||
"Content-Type: text/plain; charset=utf-8",
|
||||
"",
|
||||
body,
|
||||
}, "\r\n")
|
||||
return smtp.SendMail(addr, auth, m.From, []string{to}, []byte(msg))
|
||||
}
|
||||
61
internal/auth/ratelimit.go
Normal file
61
internal/auth/ratelimit.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrRateLimited indicates that too many requests have been made for a
|
||||
// given key within the configured window.
|
||||
var ErrRateLimited = errors.New("rate limited")
|
||||
|
||||
// RateLimiter enforces a fixed-window rate limit per email using the
|
||||
// login_attempts table. The DB-backed approach survives restarts and
|
||||
// works across multiple processes sharing the same SQLite file.
|
||||
type RateLimiter struct {
|
||||
db *sql.DB
|
||||
window time.Duration
|
||||
max int
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a limiter with the given window and maximum
|
||||
// attempts. Default for magic-link login: 5 per 15 min.
|
||||
func NewRateLimiter(db *sql.DB, window time.Duration, max int) *RateLimiter {
|
||||
return &RateLimiter{db: db, window: window, max: max}
|
||||
}
|
||||
|
||||
// Check records an attempt for email and returns ErrRateLimited if the
|
||||
// number of attempts within the window exceeds max.
|
||||
func (r *RateLimiter) Check(ctx context.Context, email string) error {
|
||||
now := time.Now().UTC()
|
||||
cutoff := now.Add(-r.window).Unix()
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
`DELETE FROM login_attempts WHERE created_at < ?`, cutoff); err != nil {
|
||||
return fmt.Errorf("prune: %w", err)
|
||||
}
|
||||
var count int
|
||||
if err := tx.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM login_attempts WHERE email = ? AND created_at >= ?`,
|
||||
email, cutoff).Scan(&count); err != nil {
|
||||
return fmt.Errorf("count: %w", err)
|
||||
}
|
||||
if count >= r.max {
|
||||
return ErrRateLimited
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
`INSERT INTO login_attempts (email, created_at) VALUES (?, ?)`,
|
||||
email, now.Unix()); err != nil {
|
||||
return fmt.Errorf("insert: %w", err)
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
147
internal/auth/service.go
Normal file
147
internal/auth/service.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"git.librete.ch/public/librenotes/internal/storage"
|
||||
)
|
||||
|
||||
// nowFn is overridable in tests.
|
||||
var nowFn = func() time.Time { return time.Now().UTC() }
|
||||
|
||||
// Service orchestrates the magic-link login flow: request a link by
|
||||
// email and verify a returned link to issue a JWT session.
|
||||
type Service struct {
|
||||
users *storage.UserStore
|
||||
tokens *TokenStore
|
||||
limiter *RateLimiter
|
||||
mailer Mailer
|
||||
signer *Signer
|
||||
baseURL string
|
||||
dataDir string
|
||||
}
|
||||
|
||||
// Config bundles dependencies for NewService.
|
||||
type Config struct {
|
||||
Users *storage.UserStore
|
||||
Tokens *TokenStore
|
||||
Limiter *RateLimiter
|
||||
Mailer Mailer
|
||||
Signer *Signer
|
||||
// BaseURL is the public origin used to build verification links,
|
||||
// e.g. "https://librenot.es".
|
||||
BaseURL string
|
||||
// DataDir is the parent directory under which per-user note
|
||||
// directories are created on first login.
|
||||
DataDir string
|
||||
}
|
||||
|
||||
// NewService validates and assembles a Service.
|
||||
func NewService(c Config) (*Service, error) {
|
||||
if c.Users == nil || c.Tokens == nil || c.Limiter == nil || c.Mailer == nil || c.Signer == nil {
|
||||
return nil, fmt.Errorf("auth: missing dependency")
|
||||
}
|
||||
if c.BaseURL == "" {
|
||||
return nil, fmt.Errorf("auth: BaseURL required")
|
||||
}
|
||||
if c.DataDir == "" {
|
||||
return nil, fmt.Errorf("auth: DataDir required")
|
||||
}
|
||||
return &Service{
|
||||
users: c.Users,
|
||||
tokens: c.Tokens,
|
||||
limiter: c.Limiter,
|
||||
mailer: c.Mailer,
|
||||
signer: c.Signer,
|
||||
baseURL: strings.TrimRight(c.BaseURL, "/"),
|
||||
dataDir: c.DataDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequestLogin generates a magic link and emails it. The user is
|
||||
// auto-created on first login. Returns ErrRateLimited if the email has
|
||||
// exceeded the limiter window.
|
||||
func (s *Service) RequestLogin(ctx context.Context, email string) error {
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
if !looksLikeEmail(email) {
|
||||
return fmt.Errorf("invalid email")
|
||||
}
|
||||
if err := s.limiter.Check(ctx, email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := s.users.GetByEmail(ctx, email)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
id := uuid.NewString()
|
||||
path := filepath.Join(s.dataDir, id)
|
||||
created, cerr := s.users.Create(ctx, storage.User{
|
||||
ID: id,
|
||||
Email: email,
|
||||
StoragePath: path,
|
||||
})
|
||||
if cerr != nil {
|
||||
return fmt.Errorf("create user: %w", cerr)
|
||||
}
|
||||
user = created
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("lookup user: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := s.tokens.Issue(ctx, user.ID, email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
link := s.baseURL + "/auth/verify?token=" + url.QueryEscape(plaintext)
|
||||
if err := s.mailer.SendMagicLink(ctx, email, link); err != nil {
|
||||
return fmt.Errorf("send mail: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyResult holds the outcome of a successful magic-link verification.
|
||||
type VerifyResult struct {
|
||||
JWT string
|
||||
User storage.User
|
||||
}
|
||||
|
||||
// Verify consumes a magic-link token and returns a signed session JWT.
|
||||
func (s *Service) Verify(ctx context.Context, token string) (VerifyResult, error) {
|
||||
userID, err := s.tokens.Consume(ctx, token)
|
||||
if err != nil {
|
||||
return VerifyResult{}, err
|
||||
}
|
||||
user, err := s.users.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return VerifyResult{}, fmt.Errorf("load user: %w", err)
|
||||
}
|
||||
now := nowFn()
|
||||
if err := s.users.UpdateLastLogin(ctx, user.ID, now); err != nil {
|
||||
return VerifyResult{}, fmt.Errorf("update last login: %w", err)
|
||||
}
|
||||
jwtStr, err := s.signer.Issue(user.ID, user.Email)
|
||||
if err != nil {
|
||||
return VerifyResult{}, err
|
||||
}
|
||||
return VerifyResult{JWT: jwtStr, User: user}, nil
|
||||
}
|
||||
|
||||
// looksLikeEmail does a minimal sanity check; full validation is left
|
||||
// to the SMTP server. We just want to reject obvious garbage.
|
||||
func looksLikeEmail(s string) bool {
|
||||
at := strings.IndexByte(s, '@')
|
||||
if at <= 0 || at == len(s)-1 {
|
||||
return false
|
||||
}
|
||||
if strings.ContainsAny(s, " \t\r\n") {
|
||||
return false
|
||||
}
|
||||
return strings.IndexByte(s[at+1:], '.') >= 0
|
||||
}
|
||||
137
internal/auth/tokens.go
Normal file
137
internal/auth/tokens.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Package auth implements magic-link authentication and JWT issuance
|
||||
// for the librenotes multi-tenant backend.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenLifetime is how long a magic link is valid after creation.
|
||||
const TokenLifetime = 15 * time.Minute
|
||||
|
||||
// ErrInvalidToken is returned for unknown, expired, or already-used tokens.
|
||||
var ErrInvalidToken = errors.New("invalid or expired token")
|
||||
|
||||
// MagicToken records a single magic-link issuance. The plaintext token
|
||||
// is never stored — only its SHA-256 hash. The plaintext lives only in
|
||||
// the email sent to the user.
|
||||
type MagicToken struct {
|
||||
UserID string
|
||||
Email string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// TokenStore persists magic-link token hashes.
|
||||
type TokenStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewTokenStore wraps a database handle.
|
||||
func NewTokenStore(db *sql.DB) *TokenStore { return &TokenStore{db: db} }
|
||||
|
||||
// Issue generates a new random token, stores its hash bound to userID,
|
||||
// and returns the plaintext token (caller emails it to the user).
|
||||
func (s *TokenStore) Issue(ctx context.Context, userID, email string) (string, error) {
|
||||
plaintext, err := randomToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
hash := hashToken(plaintext)
|
||||
_, err = s.db.ExecContext(ctx,
|
||||
`INSERT INTO magic_tokens (token_hash, user_id, email, created_at, expires_at) VALUES (?, ?, ?, ?, ?)`,
|
||||
hash, userID, email, now.Unix(), now.Add(TokenLifetime).Unix())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("insert magic token: %w", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// Consume validates the plaintext token. If valid, it marks the token
|
||||
// used (single-use) and returns the bound user ID. Comparison goes
|
||||
// through subtle.ConstantTimeCompare; the hash lookup gives us O(1)
|
||||
// retrieval without leaking timing for the row scan, and the constant-
|
||||
// time compare protects against any residual differences.
|
||||
func (s *TokenStore) Consume(ctx context.Context, plaintext string) (userID string, err error) {
|
||||
hash := hashToken(plaintext)
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("begin: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
var (
|
||||
storedHash string
|
||||
uid string
|
||||
expires int64
|
||||
used sql.NullInt64
|
||||
)
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`SELECT token_hash, user_id, expires_at, used_at FROM magic_tokens WHERE token_hash = ?`,
|
||||
hash).Scan(&storedHash, &uid, &expires, &used)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("select magic token: %w", err)
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(storedHash), []byte(hash)) != 1 {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
if used.Valid {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
if time.Now().UTC().Unix() > expires {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
res, err := tx.ExecContext(ctx,
|
||||
`UPDATE magic_tokens SET used_at = ? WHERE token_hash = ? AND used_at IS NULL`,
|
||||
time.Now().UTC().Unix(), hash)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("mark used: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n != 1 {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// PurgeExpired deletes all expired or used tokens older than retention.
|
||||
// Callers can run this on a timer to keep the table small.
|
||||
func (s *TokenStore) PurgeExpired(ctx context.Context, retention time.Duration) error {
|
||||
cutoff := time.Now().UTC().Add(-retention).Unix()
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM magic_tokens WHERE expires_at < ? OR (used_at IS NOT NULL AND used_at < ?)`,
|
||||
cutoff, cutoff)
|
||||
if err != nil {
|
||||
return fmt.Errorf("purge: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func randomToken() (string, error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", fmt.Errorf("rand: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func hashToken(plaintext string) string {
|
||||
sum := sha256.Sum256([]byte(plaintext))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -47,4 +47,18 @@ var migrations = []string{
|
||||
storage_path TEXT NOT NULL
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)`,
|
||||
`CREATE TABLE IF NOT EXISTS magic_tokens (
|
||||
token_hash TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
email TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
expires_at INTEGER NOT NULL,
|
||||
used_at INTEGER
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_magic_tokens_expires ON magic_tokens(expires_at)`,
|
||||
`CREATE TABLE IF NOT EXISTS login_attempts (
|
||||
email TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_login_attempts ON login_attempts(email, created_at)`,
|
||||
}
|
||||
|
||||
@@ -79,6 +79,20 @@ func (s *UserStore) GetByEmail(ctx context.Context, email string) (User, error)
|
||||
return s.scanOne(ctx, `SELECT id, email, created_at, last_login_at, storage_path FROM users WHERE email = ?`, normaliseEmail(email))
|
||||
}
|
||||
|
||||
// SetStoragePath updates the per-user storage path. Used during the
|
||||
// first-login auto-create flow once the UUID is known.
|
||||
func (s *UserStore) SetStoragePath(ctx context.Context, id, path string) error {
|
||||
res, err := s.db.ExecContext(ctx, `UPDATE users SET storage_path = ? WHERE id = ?`, path, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update storage_path: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastLogin records a successful login at the given instant.
|
||||
func (s *UserStore) UpdateLastLogin(ctx context.Context, id string, at time.Time) error {
|
||||
res, err := s.db.ExecContext(ctx, `UPDATE users SET last_login_at = ? WHERE id = ?`, at.Unix(), id)
|
||||
|
||||
Reference in New Issue
Block a user