harden auth config and optimize feed pulling

This commit is contained in:
Yuan
2026-02-09 21:58:51 +08:00
parent fa3f7f911b
commit f6e6a0a1bf
26 changed files with 537 additions and 94 deletions
+12 -2
View File
@@ -2,7 +2,9 @@
FUSION_DB_PATH=fusion.db
# Authentication
FUSION_PASSWORD=admin
FUSION_PASSWORD=changeme
# Explicitly allow empty password (not recommended)
# FUSION_ALLOW_EMPTY_PASSWORD=false
# Server Configuration
FUSION_PORT=8080
@@ -20,6 +22,14 @@ FUSION_PULL_CONCURRENCY=10
# Maximum backoff time in seconds (default: 604800 = 7 days)
FUSION_PULL_MAX_BACKOFF=604800
# Login rate limiting
# Max failed attempts per window (default: 10)
FUSION_LOGIN_RATE_LIMIT=10
# Window size in seconds (default: 60)
FUSION_LOGIN_WINDOW=60
# Block duration in seconds after limit is exceeded (default: 300)
FUSION_LOGIN_BLOCK=300
# Logging Configuration
# Log level: DEBUG, INFO, WARN, ERROR (default: INFO)
FUSION_LOG_LEVEL=INFO
@@ -35,7 +45,7 @@ FUSION_LOG_FORMAT=auto
# FUSION_OIDC_CLIENT_ID=
# FUSION_OIDC_CLIENT_SECRET=
# Callback URL (default: auto-detected from Host header as {scheme}://{host}/api/oidc/callback)
# Callback URL (required when OIDC issuer is configured)
# FUSION_OIDC_REDIRECT_URI=
# Restrict login to a specific user identity (email or subject claim, optional)
+11 -1
View File
@@ -2,7 +2,9 @@
FUSION_DB_PATH=fusion.db
# Authentication
FUSION_PASSWORD=admin
FUSION_PASSWORD=changeme
# Explicitly allow empty password (not recommended)
# FUSION_ALLOW_EMPTY_PASSWORD=false
# Server Configuration
FUSION_PORT=8080
@@ -20,6 +22,14 @@ FUSION_PULL_CONCURRENCY=10
# Maximum backoff time in seconds (default: 604800 = 7 days)
FUSION_PULL_MAX_BACKOFF=604800
# Login rate limiting
# Max failed attempts per window (default: 10)
FUSION_LOGIN_RATE_LIMIT=10
# Window size in seconds (default: 60)
FUSION_LOGIN_WINDOW=60
# Block duration in seconds after limit is exceeded (default: 300)
FUSION_LOGIN_BLOCK=300
# Logging Configuration
# Log level: DEBUG, INFO, WARN, ERROR (default: INFO)
FUSION_LOG_LEVEL=INFO
+4 -1
View File
@@ -27,7 +27,10 @@ func main() {
}
func run() error {
cfg := config.Load()
cfg, err := config.Load()
if err != nil {
return err
}
setupLogger(cfg)
st, err := store.New(cfg.DBPath)
-10
View File
@@ -84,16 +84,6 @@ func (a *OIDCAuthenticator) AuthURL() (authURL string, err error) {
return url, nil
}
// SetRedirectURI sets the redirect URI dynamically (for auto-detection from Host header).
func (a *OIDCAuthenticator) SetRedirectURI(uri string) {
a.oauth2Config.RedirectURL = uri
}
// RedirectURI returns the currently configured redirect URI.
func (a *OIDCAuthenticator) RedirectURI() string {
return a.oauth2Config.RedirectURL
}
// Callback exchanges the authorization code for tokens and verifies the ID token.
// Returns a user identifier (email or subject claim).
func (a *OIDCAuthenticator) Callback(ctx context.Context, state, code string) (string, error) {
+77 -15
View File
@@ -1,14 +1,15 @@
package config
import (
"fmt"
"os"
"strconv"
"strings"
)
type Config struct {
DBPath string
Password string // Plaintext password from env
Host string // TODO parse and use
Port int
PullInterval int // Pull interval in seconds (default: 1800 = 30 min)
@@ -16,6 +17,10 @@ type Config struct {
PullConcurrency int // Max concurrent pulls (default: 10)
PullMaxBackoff int // Max backoff time in seconds (default: 604800 = 7 days)
LoginRateLimit int // Max failed login attempts per window (default: 10)
LoginWindow int // Login rate limit window in seconds (default: 60)
LoginBlock int // Login block duration in seconds (default: 300)
LogLevel string // Log level: DEBUG, INFO, WARN, ERROR (default: INFO)
LogFormat string // Log format: text, json, auto (default: auto)
@@ -23,11 +28,11 @@ type Config struct {
OIDCIssuer string // OIDC provider URL
OIDCClientID string // OAuth2 client ID
OIDCClientSecret string // OAuth2 client secret
OIDCRedirectURI string // Callback URL (default: auto-detect from Host header)
OIDCRedirectURI string // Callback URL (required when OIDC is enabled)
OIDCAllowedUser string // Optional: restrict to specific user identity (email or sub)
}
func Load() *Config {
func Load() (*Config, error) {
// Backward compatible env vars:
// - DB (legacy) -> FUSION_DB_PATH
// - PASSWORD (legacy) -> FUSION_PASSWORD
@@ -44,8 +49,14 @@ func Load() *Config {
if password == "" {
password = os.Getenv("PASSWORD")
}
if password == "" {
password = "admin" // TODO allow empty password
allowEmptyPassword, err := getEnvBool("FUSION_ALLOW_EMPTY_PASSWORD", false)
if err != nil {
return nil, err
}
if strings.TrimSpace(password) == "" && !allowEmptyPassword {
return nil, fmt.Errorf("FUSION_PASSWORD is required (or set FUSION_ALLOW_EMPTY_PASSWORD=true)")
}
port := os.Getenv("FUSION_PORT")
@@ -57,7 +68,40 @@ func Load() *Config {
}
parsedPort, err := strconv.Atoi(port)
if err != nil {
panic(err)
return nil, fmt.Errorf("invalid FUSION_PORT: %w", err)
}
if parsedPort <= 0 || parsedPort > 65535 {
return nil, fmt.Errorf("invalid FUSION_PORT: must be in range 1-65535")
}
pullInterval, err := getEnvInt("FUSION_PULL_INTERVAL", 1800, 1)
if err != nil {
return nil, err
}
pullTimeout, err := getEnvInt("FUSION_PULL_TIMEOUT", 30, 1)
if err != nil {
return nil, err
}
pullConcurrency, err := getEnvInt("FUSION_PULL_CONCURRENCY", 10, 1)
if err != nil {
return nil, err
}
pullMaxBackoff, err := getEnvInt("FUSION_PULL_MAX_BACKOFF", 604800, 1)
if err != nil {
return nil, err
}
loginRateLimit, err := getEnvInt("FUSION_LOGIN_RATE_LIMIT", 10, 1)
if err != nil {
return nil, err
}
loginWindow, err := getEnvInt("FUSION_LOGIN_WINDOW", 60, 1)
if err != nil {
return nil, err
}
loginBlock, err := getEnvInt("FUSION_LOGIN_BLOCK", 300, 1)
if err != nil {
return nil, err
}
logLevel := os.Getenv("FUSION_LOG_LEVEL")
@@ -74,10 +118,13 @@ func Load() *Config {
DBPath: dbPath,
Password: password,
Port: parsedPort,
PullInterval: getEnvInt("FUSION_PULL_INTERVAL", 1800),
PullTimeout: getEnvInt("FUSION_PULL_TIMEOUT", 30),
PullConcurrency: getEnvInt("FUSION_PULL_CONCURRENCY", 10),
PullMaxBackoff: getEnvInt("FUSION_PULL_MAX_BACKOFF", 604800),
PullInterval: pullInterval,
PullTimeout: pullTimeout,
PullConcurrency: pullConcurrency,
PullMaxBackoff: pullMaxBackoff,
LoginRateLimit: loginRateLimit,
LoginWindow: loginWindow,
LoginBlock: loginBlock,
LogLevel: logLevel,
LogFormat: logFormat,
@@ -86,17 +133,32 @@ func Load() *Config {
OIDCClientSecret: os.Getenv("FUSION_OIDC_CLIENT_SECRET"),
OIDCRedirectURI: os.Getenv("FUSION_OIDC_REDIRECT_URI"),
OIDCAllowedUser: os.Getenv("FUSION_OIDC_ALLOWED_USER"),
}
}, nil
}
func getEnvInt(key string, defaultVal int) int {
func getEnvInt(key string, defaultVal, minVal int) (int, error) {
val := os.Getenv(key)
if val == "" {
return defaultVal
return defaultVal, nil
}
parsed, err := strconv.Atoi(val)
if err != nil {
return defaultVal
return 0, fmt.Errorf("invalid %s: %w", key, err)
}
return parsed
if parsed < minVal {
return 0, fmt.Errorf("invalid %s: must be >= %d", key, minVal)
}
return parsed, nil
}
func getEnvBool(key string, defaultVal bool) (bool, error) {
val := os.Getenv(key)
if val == "" {
return defaultVal, nil
}
parsed, err := strconv.ParseBool(val)
if err != nil {
return false, fmt.Errorf("invalid %s: %w", key, err)
}
return parsed, nil
}
+12 -3
View File
@@ -23,16 +23,19 @@ func (h *Handler) listBookmarks(c *gin.Context) {
if limitStr := c.Query("limit"); limitStr != "" {
val, err := strconv.Atoi(limitStr)
if err != nil {
if err != nil || val <= 0 {
badRequestError(c, "invalid limit")
return
}
if val > maxListLimit {
val = maxListLimit
}
limit = val
}
if offsetStr := c.Query("offset"); offsetStr != "" {
val, err := strconv.Atoi(offsetStr)
if err != nil {
if err != nil || val < 0 {
badRequestError(c, "invalid offset")
return
}
@@ -45,7 +48,13 @@ func (h *Handler) listBookmarks(c *gin.Context) {
return
}
listResponse(c, bookmarks, len(bookmarks))
total, err := h.store.CountBookmarks()
if err != nil {
internalError(c, err, "count bookmarks")
return
}
listResponse(c, bookmarks, total)
}
func (h *Handler) getBookmark(c *gin.Context) {
+9
View File
@@ -2,6 +2,7 @@ package handler
import (
"log/slog"
"strconv"
"github.com/gin-gonic/gin"
)
@@ -31,3 +32,11 @@ func badRequestError(c *gin.Context, message string) {
func unauthorizedError(c *gin.Context) {
c.JSON(401, gin.H{"error": "unauthorized"})
}
// tooManyRequestsError returns 429 and sets Retry-After when available.
func tooManyRequestsError(c *gin.Context, retryAfterSec int64) {
if retryAfterSec > 0 {
c.Header("Retry-After", strconv.FormatInt(retryAfterSec, 10))
}
c.JSON(429, gin.H{"error": "too many login attempts"})
}
+7
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"sync"
"github.com/0x2E/fusion/internal/auth"
@@ -23,6 +24,7 @@ type Handler struct {
sessions map[string]bool // sessionID -> valid, in-memory session store
mu sync.RWMutex // protects sessions map
oidcAuth *auth.OIDCAuthenticator // nil when OIDC is disabled
limiter *loginLimiter
}
func New(store *store.Store, config *config.Config, puller interface {
@@ -41,9 +43,14 @@ func New(store *store.Store, config *config.Config, puller interface {
passwordHash: passwordHash,
puller: puller,
sessions: make(map[string]bool),
limiter: newLoginLimiter(config.LoginRateLimit, config.LoginWindow, config.LoginBlock),
}
if config.OIDCIssuer != "" {
if strings.TrimSpace(config.OIDCRedirectURI) == "" {
return nil, fmt.Errorf("FUSION_OIDC_REDIRECT_URI is required when OIDC is enabled")
}
oidcAuth, err := auth.NewOIDC(
context.Background(),
config.OIDCIssuer,
+7 -2
View File
@@ -8,6 +8,8 @@ import (
"github.com/gin-gonic/gin"
)
const maxListLimit = 100
type markItemsReadRequest struct {
IDs []int64 `json:"ids" binding:"required"`
}
@@ -44,10 +46,13 @@ func (h *Handler) listItems(c *gin.Context) {
if limit := c.Query("limit"); limit != "" {
val, err := strconv.Atoi(limit)
if err != nil {
if err != nil || val <= 0 {
badRequestError(c, "invalid limit")
return
}
if val > maxListLimit {
val = maxListLimit
}
params.Limit = val
} else {
params.Limit = 10
@@ -55,7 +60,7 @@ func (h *Handler) listItems(c *gin.Context) {
if offset := c.Query("offset"); offset != "" {
val, err := strconv.Atoi(offset)
if err != nil {
if err != nil || val < 0 {
badRequestError(c, "invalid offset")
return
}
-10
View File
@@ -1,7 +1,6 @@
package handler
import (
"fmt"
"log/slog"
"net/http"
@@ -18,15 +17,6 @@ func (h *Handler) oidcLogin(c *gin.Context) {
return
}
// Auto-detect redirect URI from Host header if not explicitly configured
if h.oidcAuth.RedirectURI() == "" {
scheme := "https"
if !isSecureRequest(c.Request) {
scheme = "http"
}
h.oidcAuth.SetRedirectURI(fmt.Sprintf("%s://%s/api/oidc/callback", scheme, c.Request.Host))
}
authURL, err := h.oidcAuth.AuthURL()
if err != nil {
internalError(c, err, "oidc auth url")
+5 -2
View File
@@ -20,18 +20,21 @@ func (h *Handler) search(c *gin.Context) {
badRequestError(c, "invalid limit")
return
}
if parsed > maxListLimit {
parsed = maxListLimit
}
limit = parsed
}
feeds, err := h.store.SearchFeeds(q)
if err != nil {
c.JSON(500, gin.H{"error": "search feeds: " + err.Error()})
internalError(c, err, "search feeds")
return
}
items, err := h.store.SearchItems(q, limit)
if err != nil {
c.JSON(500, gin.H{"error": "search items: " + err.Error()})
internalError(c, err, "search items")
return
}
+105 -2
View File
@@ -2,12 +2,102 @@ package handler
import (
"net/http"
"sync"
"time"
"github.com/0x2E/fusion/internal/auth"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type loginState struct {
windowStart int64
failures int
blockedTill int64
}
type loginLimiter struct {
mu sync.Mutex
states map[string]loginState
limit int
windowSecs int64
blockSecs int64
lastSweepSec int64
}
func newLoginLimiter(limit, windowSecs, blockSecs int) *loginLimiter {
return &loginLimiter{
states: make(map[string]loginState),
limit: limit,
windowSecs: int64(windowSecs),
blockSecs: int64(blockSecs),
}
}
func (l *loginLimiter) allow(ip string, now time.Time) (bool, int64) {
nowSec := now.Unix()
l.mu.Lock()
defer l.mu.Unlock()
l.sweep(nowSec)
state, ok := l.states[ip]
if !ok {
return true, 0
}
if state.blockedTill > nowSec {
return false, state.blockedTill - nowSec
}
return true, 0
}
func (l *loginLimiter) recordFailure(ip string, now time.Time) {
nowSec := now.Unix()
l.mu.Lock()
defer l.mu.Unlock()
l.sweep(nowSec)
state := l.states[ip]
if state.windowStart == 0 || nowSec-state.windowStart >= l.windowSecs {
state.windowStart = nowSec
state.failures = 0
}
state.failures++
if state.failures >= l.limit {
state.blockedTill = nowSec + l.blockSecs
state.windowStart = nowSec
state.failures = 0
}
l.states[ip] = state
}
func (l *loginLimiter) recordSuccess(ip string) {
l.mu.Lock()
defer l.mu.Unlock()
delete(l.states, ip)
}
func (l *loginLimiter) sweep(nowSec int64) {
if nowSec-l.lastSweepSec < 60 {
return
}
l.lastSweepSec = nowSec
for ip, state := range l.states {
windowExpired := state.windowStart > 0 && nowSec-state.windowStart >= l.windowSecs
unblocked := state.blockedTill > 0 && state.blockedTill <= nowSec
if (state.blockedTill == 0 && windowExpired) || unblocked {
delete(l.states, ip)
}
}
}
func isSecureRequest(r *http.Request) bool {
if r.TLS != nil {
return true
@@ -16,20 +106,33 @@ func isSecureRequest(r *http.Request) bool {
}
type loginRequest struct {
Password string `json:"password" binding:"required"`
Password *string `json:"password"`
}
func (h *Handler) login(c *gin.Context) {
ip := c.ClientIP()
allowed, retryAfter := h.limiter.allow(ip, time.Now())
if !allowed {
tooManyRequestsError(c, retryAfter)
return
}
var req loginRequest
if err := c.ShouldBindJSON(&req); err != nil {
badRequestError(c, "invalid request")
return
}
if req.Password == nil {
badRequestError(c, "invalid request")
return
}
if err := auth.CheckPassword(h.passwordHash, req.Password); err != nil {
if err := auth.CheckPassword(h.passwordHash, *req.Password); err != nil {
h.limiter.recordFailure(ip, time.Now())
unauthorizedError(c)
return
}
h.limiter.recordSuccess(ip)
h.createSession(c)
dataResponse(c, gin.H{"message": "logged in"})
@@ -0,0 +1,36 @@
package handler
import "testing"
func TestLoginLimiterSweepDeletesExpiredWindowWithoutBlock(t *testing.T) {
limiter := newLoginLimiter(3, 10, 30)
limiter.states["1.1.1.1"] = loginState{windowStart: 10, failures: 1}
limiter.sweep(120)
if _, ok := limiter.states["1.1.1.1"]; ok {
t.Fatal("expected expired non-blocked state to be deleted")
}
}
func TestLoginLimiterSweepDeletesUnblockedState(t *testing.T) {
limiter := newLoginLimiter(3, 60, 30)
limiter.states["2.2.2.2"] = loginState{windowStart: 100, blockedTill: 110}
limiter.sweep(120)
if _, ok := limiter.states["2.2.2.2"]; ok {
t.Fatal("expected unblocked state to be deleted")
}
}
func TestLoginLimiterSweepKeepsActiveBlockedState(t *testing.T) {
limiter := newLoginLimiter(3, 60, 30)
limiter.states["3.3.3.3"] = loginState{windowStart: 100, blockedTill: 170}
limiter.sweep(120)
if _, ok := limiter.states["3.3.3.3"]; !ok {
t.Fatal("expected active blocked state to remain")
}
}
+65
View File
@@ -0,0 +1,65 @@
package handler
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/0x2E/fusion/internal/auth"
"github.com/gin-gonic/gin"
)
func newTestSessionHandler(t *testing.T, password string) *Handler {
t.Helper()
hash, err := auth.HashPassword(password)
if err != nil {
t.Fatalf("hash password: %v", err)
}
return &Handler{
passwordHash: hash,
sessions: make(map[string]bool),
limiter: newLoginLimiter(10, 60, 300),
}
}
func TestLoginRejectsMissingPasswordField(t *testing.T) {
gin.SetMode(gin.TestMode)
h := newTestSessionHandler(t, "secret")
r := gin.New()
r.POST("/api/sessions", h.login)
req := httptest.NewRequest(http.MethodPost, "/api/sessions", strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 400, got %d", w.Code)
}
}
func TestLoginAcceptsEmptyPassword(t *testing.T) {
gin.SetMode(gin.TestMode)
h := newTestSessionHandler(t, "")
r := gin.New()
r.POST("/api/sessions", h.login)
req := httptest.NewRequest(http.MethodPost, "/api/sessions", strings.NewReader(`{"password":""}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if cookie := w.Header().Get("Set-Cookie"); !strings.Contains(cookie, "session=") {
t.Fatalf("expected session cookie to be set, got %q", cookie)
}
}
+8 -6
View File
@@ -16,12 +16,14 @@ type Feed struct {
Link string `json:"link"`
SiteURL string `json:"site_url,omitempty"`
LastBuild int64 `json:"last_build"`
Failure string `json:"failure,omitempty"`
Failures int64 `json:"failures"`
Suspended bool `json:"suspended"`
Proxy string `json:"proxy,omitempty"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
// LastFailureAt is Unix timestamp of the most recent pull failure.
LastFailureAt int64 `json:"last_failure_at"`
Failure string `json:"failure,omitempty"`
Failures int64 `json:"failures"`
Suspended bool `json:"suspended"`
Proxy string `json:"proxy,omitempty"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
UnreadCount int64 `json:"unread_count"`
ItemCount int64 `json:"item_count"`
+38 -6
View File
@@ -3,15 +3,37 @@ package httpc
import (
"net/http"
"net/url"
"strconv"
"sync"
"time"
)
type clientPool struct {
mu sync.RWMutex
clients map[string]*http.Client
}
var defaultClientPool = &clientPool{clients: make(map[string]*http.Client)}
// NewClient creates HTTP client with specified timeout and optional proxy.
// Returns client configured for HTTP/2 with keep-alives disabled.
// Clients are reused by (timeout, proxy) to keep connections warm.
func NewClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
key := proxyURL + "|" + strconv.FormatInt(timeout.Milliseconds(), 10)
defaultClientPool.mu.RLock()
if client, ok := defaultClientPool.clients[key]; ok {
defaultClientPool.mu.RUnlock()
return client, nil
}
defaultClientPool.mu.RUnlock()
transport := &http.Transport{
DisableKeepAlives: true,
ForceAttemptHTTP2: true,
ForceAttemptHTTP2: true,
MaxIdleConns: 128,
MaxIdleConnsPerHost: 16,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: timeout,
}
if proxyURL != "" {
@@ -22,14 +44,24 @@ func NewClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
transport.Proxy = http.ProxyURL(proxy)
}
return &http.Client{
client := &http.Client{
Timeout: timeout,
Transport: transport,
}, nil
}
defaultClientPool.mu.Lock()
if existing, ok := defaultClientPool.clients[key]; ok {
defaultClientPool.mu.Unlock()
transport.CloseIdleConnections()
return existing, nil
}
defaultClientPool.clients[key] = client
defaultClientPool.mu.Unlock()
return client, nil
}
// SetDefaultHeaders adds default headers required for feed fetching.
func SetDefaultHeaders(req *http.Request) {
req.Header.Set("User-Agent", "fusion/1.0")
req.Header.Set("Connection", "close")
}
+5 -1
View File
@@ -38,7 +38,11 @@ func ShouldSkip(feed *model.Feed, interval, maxBackoff time.Duration) bool {
// Skip if in backoff period
if feed.Failures > 0 {
backoff := CalculateBackoff(interval, feed.Failures, maxBackoff)
nextPull := feed.LastBuild + int64(backoff.Seconds())
base := feed.LastFailureAt
if base <= 0 {
base = feed.LastBuild
}
nextPull := base + int64(backoff.Seconds())
if now < nextPull {
return true
}
+13 -15
View File
@@ -94,23 +94,21 @@ func (p *Puller) pullFeed(ctx context.Context, feed *model.Feed) {
return
}
newCount := 0
inputs := make([]store.BatchCreateItemInput, 0, len(items))
for _, item := range items {
exists, err := p.store.ItemExists(feed.ID, item.GUID)
if err != nil {
p.logger.Error("failed to check item existence", "feed_id", feed.ID, "error", err)
continue
}
if exists {
continue
}
inputs = append(inputs, store.BatchCreateItemInput{
GUID: item.GUID,
Title: item.Title,
Link: item.Link,
Content: item.Content,
PubDate: item.PubDate,
})
}
_, err = p.store.CreateItem(feed.ID, item.GUID, item.Title, item.Link, item.Content, item.PubDate)
if err != nil {
p.logger.Error("failed to create item", "feed_id", feed.ID, "error", err)
continue
}
newCount++
newCount, err := p.store.BatchCreateItemsIgnore(feed.ID, inputs)
if err != nil {
p.logger.Error("failed to batch create items", "feed_id", feed.ID, "error", err)
return
}
if err := p.store.UpdateFeedLastBuild(feed.ID, time.Now().Unix()); err != nil {
+6
View File
@@ -98,3 +98,9 @@ func (s *Store) BookmarkExists(link string) (bool, error) {
err := s.db.QueryRow(`SELECT EXISTS(SELECT 1 FROM bookmarks WHERE link = :link)`, sql.Named("link", link)).Scan(&exists)
return exists, err
}
func (s *Store) CountBookmarks() (int, error) {
var count int
err := s.db.QueryRow(`SELECT COUNT(*) FROM bookmarks`).Scan(&count)
return count, err
}
+11 -8
View File
@@ -11,11 +11,14 @@ import (
func (s *Store) ListFeeds() ([]*model.Feed, error) {
rows, err := s.db.Query(`
SELECT f.id, f.group_id, f.name, f.link, f.site_url, f.last_build,
SELECT f.id, f.group_id, f.name, f.link, f.site_url, f.last_build, f.last_failure_at,
f.failure, f.failures, f.suspended, f.proxy, f.created_at, f.updated_at,
(SELECT COUNT(*) FROM items WHERE feed_id = f.id AND unread = 1) AS unread_count,
(SELECT COUNT(*) FROM items WHERE feed_id = f.id) AS item_count
COALESCE(SUM(CASE WHEN i.unread = 1 THEN 1 ELSE 0 END), 0) AS unread_count,
COALESCE(COUNT(i.id), 0) AS item_count
FROM feeds f
LEFT JOIN items i ON i.feed_id = f.id
GROUP BY f.id, f.group_id, f.name, f.link, f.site_url, f.last_build, f.last_failure_at,
f.failure, f.failures, f.suspended, f.proxy, f.created_at, f.updated_at
ORDER BY f.id
`)
if err != nil {
@@ -27,7 +30,7 @@ func (s *Store) ListFeeds() ([]*model.Feed, error) {
for rows.Next() {
f := &model.Feed{}
var suspended int
if err := rows.Scan(&f.ID, &f.GroupID, &f.Name, &f.Link, &f.SiteURL, &f.LastBuild, &f.Failure, &f.Failures, &suspended, &f.Proxy, &f.CreatedAt, &f.UpdatedAt, &f.UnreadCount, &f.ItemCount); err != nil {
if err := rows.Scan(&f.ID, &f.GroupID, &f.Name, &f.Link, &f.SiteURL, &f.LastBuild, &f.LastFailureAt, &f.Failure, &f.Failures, &suspended, &f.Proxy, &f.CreatedAt, &f.UpdatedAt, &f.UnreadCount, &f.ItemCount); err != nil {
return nil, err
}
f.Suspended = intToBool(suspended)
@@ -40,10 +43,10 @@ func (s *Store) GetFeed(id int64) (*model.Feed, error) {
f := &model.Feed{}
var suspended int
err := s.db.QueryRow(`
SELECT id, group_id, name, link, site_url, last_build, failure, failures, suspended, proxy, created_at, updated_at
SELECT id, group_id, name, link, site_url, last_build, last_failure_at, failure, failures, suspended, proxy, created_at, updated_at
FROM feeds
WHERE id = :id
`, sql.Named("id", id)).Scan(&f.ID, &f.GroupID, &f.Name, &f.Link, &f.SiteURL, &f.LastBuild, &f.Failure, &f.Failures, &suspended, &f.Proxy, &f.CreatedAt, &f.UpdatedAt)
`, sql.Named("id", id)).Scan(&f.ID, &f.GroupID, &f.Name, &f.Link, &f.SiteURL, &f.LastBuild, &f.LastFailureAt, &f.Failure, &f.Failures, &suspended, &f.Proxy, &f.CreatedAt, &f.UpdatedAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("%w: feed", ErrNotFound)
@@ -205,7 +208,7 @@ func (s *Store) DeleteFeed(id int64) error {
func (s *Store) UpdateFeedLastBuild(id int64, lastBuild int64) error {
_, err := s.db.Exec(`
UPDATE feeds
SET last_build = :last_build, failures = 0, failure = '', updated_at = unixepoch()
SET last_build = :last_build, last_failure_at = 0, failures = 0, failure = '', updated_at = unixepoch()
WHERE id = :id
`, sql.Named("last_build", lastBuild), sql.Named("id", id))
return err
@@ -214,7 +217,7 @@ func (s *Store) UpdateFeedLastBuild(id int64, lastBuild int64) error {
func (s *Store) UpdateFeedFailure(id int64, failure string) error {
_, err := s.db.Exec(`
UPDATE feeds
SET failures = failures + 1, failure = :failure, updated_at = unixepoch()
SET failures = failures + 1, failure = :failure, last_failure_at = unixepoch(), updated_at = unixepoch()
WHERE id = :id
`, sql.Named("failure", failure), sql.Named("id", id))
return err
+10
View File
@@ -294,6 +294,9 @@ func TestUpdateFeedLastBuild(t *testing.T) {
if updated.Failures != 0 {
t.Error("expected failures to be cleared")
}
if updated.LastFailureAt != 0 {
t.Error("expected last_failure_at to be cleared")
}
}
func TestUpdateFeedFailure(t *testing.T) {
@@ -327,6 +330,10 @@ func TestUpdateFeedFailure(t *testing.T) {
if updated1.Failures != 1 {
t.Errorf("expected failures count to be 1, got %d", updated1.Failures)
}
if updated1.LastFailureAt == 0 {
t.Error("expected last_failure_at to be set after failure")
}
firstFailureAt := updated1.LastFailureAt
errorMsg2 := "second error"
if err := store.UpdateFeedFailure(feed.ID, errorMsg2); err != nil {
@@ -345,6 +352,9 @@ func TestUpdateFeedFailure(t *testing.T) {
if updated2.Failures != 2 {
t.Errorf("expected failures count to be 2, got %d", updated2.Failures)
}
if updated2.LastFailureAt < firstFailureAt {
t.Error("expected last_failure_at to be monotonic")
}
}
func TestUpdateFeedSiteURLIfEmpty(t *testing.T) {
+61
View File
@@ -122,6 +122,67 @@ func (s *Store) CreateItem(feedID int64, guid, title, link, content string, pubD
return s.GetItem(id)
}
type BatchCreateItemInput struct {
GUID string
Title string
Link string
Content string
PubDate int64
}
// BatchCreateItemsIgnore inserts items in one transaction and ignores duplicates by (feed_id, guid).
// Returns the number of newly inserted rows.
func (s *Store) BatchCreateItemsIgnore(feedID int64, inputs []BatchCreateItemInput) (int, error) {
if len(inputs) == 0 {
return 0, nil
}
tx, err := s.db.Begin()
if err != nil {
return 0, err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`
INSERT INTO items (feed_id, guid, title, link, content, pub_date)
VALUES (:feed_id, :guid, :title, :link, :content, :pub_date)
ON CONFLICT(feed_id, guid) DO NOTHING
`)
if err != nil {
return 0, err
}
defer stmt.Close()
created := 0
for _, input := range inputs {
result, err := stmt.Exec(
sql.Named("feed_id", feedID),
sql.Named("guid", input.GUID),
sql.Named("title", input.Title),
sql.Named("link", input.Link),
sql.Named("content", input.Content),
sql.Named("pub_date", input.PubDate),
)
if err != nil {
return 0, err
}
affected, err := result.RowsAffected()
if err != nil {
return 0, err
}
if affected > 0 {
created++
}
}
if err := tx.Commit(); err != nil {
return 0, err
}
return created, nil
}
func (s *Store) UpdateItemUnread(id int64, unread bool) error {
result, err := s.db.Exec(`UPDATE items SET unread = :unread WHERE id = :id`,
sql.Named("unread", boolToInt(unread)), sql.Named("id", id))
+12 -9
View File
@@ -1,6 +1,7 @@
package store
import (
"database/sql"
"testing"
)
@@ -39,15 +40,17 @@ func TestMigrationVersionTracking(t *testing.T) {
t.Error("schema_migrations table is empty, but migrations should have been applied")
}
// Verify version 1 was applied
var applied bool
err = store.db.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = 1)").Scan(&applied)
if err != nil {
t.Fatalf("failed to check version 1: %v", err)
}
if !applied {
t.Error("migration version 1 was not applied")
// Verify all expected versions were applied
versions := []int{1, 2, 3}
for _, version := range versions {
var applied bool
err = store.db.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = :version)", sql.Named("version", version)).Scan(&applied)
if err != nil {
t.Fatalf("failed to check version %d: %v", version, err)
}
if !applied {
t.Errorf("migration version %d was not applied", version)
}
}
}
@@ -0,0 +1 @@
ALTER TABLE feeds ADD COLUMN last_failure_at INTEGER NOT NULL DEFAULT 0;
@@ -0,0 +1 @@
CREATE INDEX IF NOT EXISTS idx_items_feed_unread ON items(feed_id, unread);
+21 -1
View File
@@ -6,17 +6,37 @@
package store
import (
"context"
"database/sql"
"fmt"
"sync"
_ "modernc.org/sqlite"
"modernc.org/sqlite"
)
type Store struct {
db *sql.DB
}
var sqliteHookOnce sync.Once
func New(dbPath string) (*Store, error) {
sqliteHookOnce.Do(func() {
sqlite.RegisterConnectionHook(func(conn sqlite.ExecQuerierContext, _ string) error {
ctx := context.Background()
if _, err := conn.ExecContext(ctx, "PRAGMA foreign_keys = ON", nil); err != nil {
return fmt.Errorf("enable foreign_keys: %w", err)
}
if _, err := conn.ExecContext(ctx, "PRAGMA busy_timeout = 5000", nil); err != nil {
return fmt.Errorf("set busy_timeout: %w", err)
}
if _, err := conn.ExecContext(ctx, "PRAGMA journal_mode = WAL", nil); err != nil {
return fmt.Errorf("set journal_mode: %w", err)
}
return nil
})
})
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)