harden feed networking and improve store query performance

This commit is contained in:
Yuan
2026-02-09 23:48:35 +08:00
parent e3c009d7d2
commit 67afb207c5
18 changed files with 549 additions and 141 deletions
+7
View File
@@ -9,6 +9,13 @@ FUSION_PASSWORD=changeme
# Server Configuration
FUSION_PORT=8080
# CORS allowed origins (comma-separated, empty means allow all)
# Example: FUSION_CORS_ALLOWED_ORIGINS=https://app.example.com,https://admin.example.com
# FUSION_CORS_ALLOWED_ORIGINS=
# Allow pulling private/localhost feed URLs (default: false)
# FUSION_ALLOW_PRIVATE_FEEDS=false
# Feed Pull Service Configuration
# Pull interval in seconds (default: 1800 = 30 minutes)
FUSION_PULL_INTERVAL=1800
-38
View File
@@ -1,38 +0,0 @@
# Database Configuration
FUSION_DB_PATH=fusion.db
# Authentication
FUSION_PASSWORD=changeme
# Explicitly allow empty password (not recommended)
# FUSION_ALLOW_EMPTY_PASSWORD=false
# Server Configuration
FUSION_PORT=8080
# Feed Pull Service Configuration
# Pull interval in seconds (default: 1800 = 30 minutes)
FUSION_PULL_INTERVAL=1800
# Request timeout in seconds (default: 30)
FUSION_PULL_TIMEOUT=30
# Maximum concurrent pulls (default: 10)
FUSION_PULL_CONCURRENCY=10
# Maximum backoff time in seconds (default: 604800 = 7 days)
FUSION_PULL_MAX_BACKOFF=604800
# Login rate limiting
# Max failed attempts per window (default: 10)
FUSION_LOGIN_RATE_LIMIT=10
# Window size in seconds (default: 60)
FUSION_LOGIN_WINDOW=60
# Block duration in seconds after limit is exceeded (default: 300)
FUSION_LOGIN_BLOCK=300
# Logging Configuration
# Log level: DEBUG, INFO, WARN, ERROR (default: INFO)
FUSION_LOG_LEVEL=INFO
# Log format: text, json, auto (default: auto)
FUSION_LOG_FORMAT=auto
+8 -1
View File
@@ -47,7 +47,14 @@ func run() error {
r := h.SetupRouter()
addr := ":" + strconv.Itoa(cfg.Port)
srv := &http.Server{Addr: addr, Handler: r}
srv := &http.Server{
Addr: addr,
Handler: r,
ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: 15 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
}
sigCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
+49 -12
View File
@@ -12,6 +12,10 @@ type Config struct {
Password string // Plaintext password from env
Port int
CORSAllowedOrigins []string // Allowed Origins for CORS. Empty means allow all.
TrustedProxies []string // Trusted reverse proxies for client IP resolution. Empty disables proxy trust.
AllowPrivateFeeds bool // Allow pulling private/localhost feed URLs.
PullInterval int // Pull interval in seconds (default: 1800 = 30 min)
PullTimeout int // Request timeout in seconds (default: 30)
PullConcurrency int // Max concurrent pulls (default: 10)
@@ -104,6 +108,14 @@ func Load() (*Config, error) {
return nil, err
}
corsAllowedOrigins := parseCSVEnv(os.Getenv("FUSION_CORS_ALLOWED_ORIGINS"))
trustedProxies := parseCSVEnv(os.Getenv("FUSION_TRUSTED_PROXIES"))
allowPrivateFeeds, err := getEnvBool("FUSION_ALLOW_PRIVATE_FEEDS", false)
if err != nil {
return nil, err
}
logLevel := os.Getenv("FUSION_LOG_LEVEL")
if logLevel == "" {
logLevel = "INFO"
@@ -115,18 +127,21 @@ func Load() (*Config, error) {
}
return &Config{
DBPath: dbPath,
Password: password,
Port: parsedPort,
PullInterval: pullInterval,
PullTimeout: pullTimeout,
PullConcurrency: pullConcurrency,
PullMaxBackoff: pullMaxBackoff,
LoginRateLimit: loginRateLimit,
LoginWindow: loginWindow,
LoginBlock: loginBlock,
LogLevel: logLevel,
LogFormat: logFormat,
DBPath: dbPath,
Password: password,
Port: parsedPort,
CORSAllowedOrigins: corsAllowedOrigins,
TrustedProxies: trustedProxies,
AllowPrivateFeeds: allowPrivateFeeds,
PullInterval: pullInterval,
PullTimeout: pullTimeout,
PullConcurrency: pullConcurrency,
PullMaxBackoff: pullMaxBackoff,
LoginRateLimit: loginRateLimit,
LoginWindow: loginWindow,
LoginBlock: loginBlock,
LogLevel: logLevel,
LogFormat: logFormat,
OIDCIssuer: os.Getenv("FUSION_OIDC_ISSUER"),
OIDCClientID: os.Getenv("FUSION_OIDC_CLIENT_ID"),
@@ -162,3 +177,25 @@ func getEnvBool(key string, defaultVal bool) (bool, error) {
}
return parsed, nil
}
// parseCSVEnv splits a comma-separated environment value into its entries.
// Each entry is whitespace-trimmed and empty entries are dropped. The result
// is nil when val is blank or contains nothing but separators and whitespace.
func parseCSVEnv(val string) []string {
	if strings.TrimSpace(val) == "" {
		return nil
	}

	fields := strings.Split(val, ",")
	entries := make([]string, 0, len(fields))
	for _, field := range fields {
		if trimmed := strings.TrimSpace(field); trimmed != "" {
			entries = append(entries, trimmed)
		}
	}

	if len(entries) > 0 {
		return entries
	}
	return nil
}
+95 -12
View File
@@ -4,12 +4,13 @@ import (
"context"
"errors"
"log/slog"
"net/url"
"net/http"
"strconv"
"strings"
"time"
"github.com/0x2E/feedfinder"
"github.com/0x2E/fusion/internal/pkg/httpc"
"github.com/0x2E/fusion/internal/store"
"github.com/gin-gonic/gin"
"github.com/mmcdole/gofeed"
@@ -56,6 +57,8 @@ type batchCreateFeedItem struct {
SiteURL string `json:"site_url"`
}
// refreshAllTimeout is the deadline applied to the background context of a
// manually triggered "refresh all feeds" run.
const refreshAllTimeout = 30 * time.Minute
func (h *Handler) listFeeds(c *gin.Context) {
feeds, err := h.store.ListFeeds()
if err != nil {
@@ -92,6 +95,10 @@ func (h *Handler) createFeed(c *gin.Context) {
badRequestError(c, "invalid request")
return
}
if err := httpc.ValidateRequestURL(c.Request.Context(), req.Link, h.config.AllowPrivateFeeds); err != nil {
badRequestError(c, "invalid link")
return
}
feed, err := h.store.CreateFeed(req.GroupID, req.Name, req.Link, req.SiteURL, req.Proxy)
if err != nil {
@@ -133,6 +140,10 @@ func (h *Handler) updateFeed(c *gin.Context) {
params.Name = req.Name
}
if req.Link != nil {
if err := httpc.ValidateRequestURL(c.Request.Context(), *req.Link, h.config.AllowPrivateFeeds); err != nil {
badRequestError(c, "invalid link")
return
}
params.Link = req.Link
}
if req.SiteURL != nil {
@@ -194,8 +205,8 @@ func (h *Handler) validateFeed(c *gin.Context) {
}
target := strings.TrimSpace(req.URL)
parsedURL, err := url.ParseRequestURI(target)
if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
allowPrivateFeeds := h.config != nil && h.config.AllowPrivateFeeds
if err := httpc.ValidateRequestURL(c.Request.Context(), target, allowPrivateFeeds); err != nil {
badRequestError(c, "invalid url")
return
}
@@ -209,14 +220,19 @@ func (h *Handler) validateFeed(c *gin.Context) {
}
feeds := normalizeDiscoveredFeeds(found)
if len(feeds) == 0 {
parser := gofeed.NewParser()
parsedFeed, parseErr := parser.ParseURLWithContext(target, ctx)
if parseErr == nil {
title := ""
if parsedFeed != nil {
title = strings.TrimSpace(parsedFeed.Title)
if !allowPrivateFeeds {
filtered := make([]discoveredFeed, 0, len(feeds))
for _, feed := range feeds {
if err := httpc.ValidateRequestURL(ctx, feed.Link, false); err == nil {
filtered = append(filtered, feed)
}
}
feeds = filtered
}
if len(feeds) == 0 {
title, parseErr := h.parseFeedTitle(ctx, target)
if parseErr == nil {
feeds = append(feeds, discoveredFeed{Title: title, Link: target})
}
}
@@ -247,6 +263,42 @@ func normalizeDiscoveredFeeds(found []feedfinder.Feed) []discoveredFeed {
return result
}
// parseFeedTitle downloads target and parses the response body as a feed,
// returning the feed's whitespace-trimmed title.
//
// The request uses a pooled client with a 30 second timeout and no proxy; the
// client honours the handler's AllowPrivateFeeds setting (treated as false
// when no config is set). Any status other than 200 OK is reported as an
// error. A nil parse result without an error yields an empty title.
func (h *Handler) parseFeedTitle(ctx context.Context, target string) (string, error) {
	privateAllowed := h.config != nil && h.config.AllowPrivateFeeds

	httpClient, err := httpc.NewClient(30*time.Second, "", privateAllowed)
	if err != nil {
		return "", err
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return "", err
	}
	httpc.SetDefaultHeaders(request)

	response, err := httpClient.Do(request)
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return "", errors.New("feed fetch failed")
	}

	feed, err := gofeed.NewParser().Parse(response.Body)
	if err != nil {
		return "", err
	}

	title := ""
	if feed != nil {
		title = strings.TrimSpace(feed.Title)
	}
	return title, nil
}
func (h *Handler) refreshFeed(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -278,10 +330,19 @@ func (h *Handler) refreshFeed(c *gin.Context) {
}
func (h *Handler) refreshAllFeeds(c *gin.Context) {
if !h.tryStartRefreshAll() {
dataResponse(c, gin.H{"message": "refresh already running"})
return
}
// Run in background so the HTTP response returns immediately.
go func() {
ctx := context.Background()
if count, err := h.puller.RefreshAll(ctx); err != nil {
defer h.finishRefreshAll()
ctx, cancel := context.WithTimeout(context.Background(), refreshAllTimeout)
defer cancel()
if count, err := h.puller.RefreshAll(ctx); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
slog.Warn("refresh all feeds failed", "refreshed", count, "error", err)
}
}()
@@ -289,6 +350,24 @@ func (h *Handler) refreshAllFeeds(c *gin.Context) {
dataResponse(c, gin.H{"message": "refresh triggered"})
}
// tryStartRefreshAll marks a refresh-all run as in progress. It reports true
// when the caller claimed the run, and false when one was already running.
func (h *Handler) tryStartRefreshAll() bool {
	h.refreshAllMu.Lock()
	alreadyRunning := h.refreshAllRunning
	// Setting the flag unconditionally is safe: it is already true whenever
	// another run holds it.
	h.refreshAllRunning = true
	h.refreshAllMu.Unlock()

	return !alreadyRunning
}
// finishRefreshAll clears the in-progress flag set by tryStartRefreshAll so a
// new refresh-all run can be started.
func (h *Handler) finishRefreshAll() {
	h.refreshAllMu.Lock()
	defer h.refreshAllMu.Unlock()

	h.refreshAllRunning = false
}
func (h *Handler) batchCreateFeeds(c *gin.Context) {
var req batchCreateFeedsRequest
if err := c.ShouldBindJSON(&req); err != nil {
@@ -298,6 +377,10 @@ func (h *Handler) batchCreateFeeds(c *gin.Context) {
inputs := make([]store.BatchCreateFeedsInput, len(req.Feeds))
for i, f := range req.Feeds {
if err := httpc.ValidateRequestURL(c.Request.Context(), f.Link, h.config.AllowPrivateFeeds); err != nil {
badRequestError(c, "invalid link")
return
}
inputs[i] = store.BatchCreateFeedsInput{
GroupID: f.GroupID,
Name: f.Name,
+53 -12
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
@@ -21,10 +22,14 @@ type Handler struct {
RefreshFeed(ctx context.Context, feedID int64) error
RefreshAll(ctx context.Context) (int, error)
}
sessions map[string]bool // sessionID -> valid, in-memory session store
mu sync.RWMutex // protects sessions map
oidcAuth *auth.OIDCAuthenticator // nil when OIDC is disabled
limiter *loginLimiter
sessions map[string]int64 // sessionID -> unix expiry seconds
mu sync.RWMutex // protects sessions state
oidcAuth *auth.OIDCAuthenticator // nil when OIDC is disabled
limiter *loginLimiter
lastSweep int64
refreshAllMu sync.Mutex
refreshAllRunning bool
}
func New(store *store.Store, config *config.Config, puller interface {
@@ -42,7 +47,7 @@ func New(store *store.Store, config *config.Config, puller interface {
config: config,
passwordHash: passwordHash,
puller: puller,
sessions: make(map[string]bool),
sessions: make(map[string]int64),
limiter: newLoginLimiter(config.LoginRateLimit, config.LoginWindow, config.LoginBlock),
}
@@ -73,6 +78,9 @@ func New(store *store.Store, config *config.Config, puller interface {
func (h *Handler) SetupRouter() *gin.Engine {
r := gin.Default()
if err := h.configureTrustedProxies(r); err != nil {
slog.Warn("failed to configure trusted proxies", "error", err)
}
r.Use(h.corsMiddleware())
@@ -124,11 +132,22 @@ func (h *Handler) SetupRouter() *gin.Engine {
return r
}
// configureTrustedProxies applies the configured trusted proxy list to the
// gin engine. With no config or an empty list it passes nil, which disables
// proxy trust.
func (h *Handler) configureTrustedProxies(r *gin.Engine) error {
	var proxies []string
	if h.config != nil && len(h.config.TrustedProxies) > 0 {
		proxies = h.config.TrustedProxies
	}
	return r.SetTrustedProxies(proxies)
}
func (h *Handler) corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
origin := strings.TrimSpace(c.Request.Header.Get("Origin"))
if origin != "" {
// Cookie-based auth needs a concrete origin ("*" + credentials is rejected by browsers).
if !h.isOriginAllowed(origin) {
c.AbortWithStatus(http.StatusForbidden)
return
}
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Set("Vary", "Origin")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
@@ -147,6 +166,32 @@ func (h *Handler) corsMiddleware() gin.HandlerFunc {
}
}
// isOriginAllowed reports whether origin may receive CORS headers.
//
// Every origin is allowed when there is no config or the allow-list is empty.
// Otherwise origin must equal one of the configured entries after
// normalization, or the list must contain the wildcard "*".
func (h *Handler) isOriginAllowed(origin string) bool {
	if h.config == nil || len(h.config.CORSAllowedOrigins) == 0 {
		return true
	}

	wanted := normalizeOrigin(origin)
	for _, entry := range h.config.CORSAllowedOrigins {
		switch normalizeOrigin(entry) {
		case "*", wanted:
			return true
		}
	}
	return false
}
// normalizeOrigin canonicalizes an origin for comparison: surrounding
// whitespace is removed, a single trailing "/" is stripped, and the result is
// lower-cased.
func normalizeOrigin(origin string) string {
	trimmed := strings.TrimSpace(origin)
	return strings.ToLower(strings.TrimSuffix(trimmed, "/"))
}
func (h *Handler) authMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
sessionID, err := c.Cookie("session")
@@ -156,11 +201,7 @@ func (h *Handler) authMiddleware() gin.HandlerFunc {
return
}
h.mu.RLock()
valid := h.sessions[sessionID]
h.mu.RUnlock()
if !valid {
if !h.isSessionValid(sessionID) {
unauthorizedError(c)
c.Abort()
return
+9
View File
@@ -9,6 +9,7 @@ import (
)
const maxListLimit = 100
// maxBatchUpdateIDs caps how many item IDs a single mark-read or mark-unread
// request may carry; larger requests are rejected as bad requests.
const maxBatchUpdateIDs = 1000
type markItemsReadRequest struct {
IDs []int64 `json:"ids" binding:"required"`
@@ -114,6 +115,10 @@ func (h *Handler) markItemsRead(c *gin.Context) {
badRequestError(c, "invalid request")
return
}
if len(req.IDs) == 0 || len(req.IDs) > maxBatchUpdateIDs {
badRequestError(c, "invalid ids")
return
}
if err := h.store.BatchUpdateItemsUnread(req.IDs, false); err != nil {
internalError(c, err, "mark items as read")
@@ -129,6 +134,10 @@ func (h *Handler) markItemsUnread(c *gin.Context) {
badRequestError(c, "invalid request")
return
}
if len(req.IDs) == 0 || len(req.IDs) > maxBatchUpdateIDs {
badRequestError(c, "invalid ids")
return
}
if err := h.store.BatchUpdateItemsUnread(req.IDs, true); err != nil {
internalError(c, err, "mark items as unread")
+2 -1
View File
@@ -2,12 +2,13 @@ package handler
import (
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
func (h *Handler) search(c *gin.Context) {
q := c.Query("q")
q := strings.TrimSpace(c.Query("q"))
if q == "" {
badRequestError(c, "q parameter is required")
return
+44 -2
View File
@@ -10,6 +10,11 @@ import (
"github.com/google/uuid"
)
const (
// sessionTTL is the lifetime of a login session; it drives both the
// server-side expiry timestamp and the session cookie's MaxAge.
sessionTTL = 30 * 24 * time.Hour
// sessionSweepInterval is the minimum time between two sweeps that purge
// expired sessions from the in-memory session map.
sessionSweepInterval = 60 * time.Second
)
type loginState struct {
windowStart int64
failures int
@@ -138,19 +143,56 @@ func (h *Handler) login(c *gin.Context) {
dataResponse(c, gin.H{"message": "logged in"})
}
// isSessionValid reports whether sessionID refers to a known, unexpired
// session. It takes the write lock because it may sweep expired sessions and
// removes the looked-up session itself when it has expired.
func (h *Handler) isSessionValid(sessionID string) bool {
	now := time.Now().Unix()

	h.mu.Lock()
	defer h.mu.Unlock()

	h.sweepExpiredSessionsLocked(now)

	expiry, known := h.sessions[sessionID]
	switch {
	case !known:
		return false
	case expiry <= now:
		// Expired but not yet swept: drop it eagerly.
		delete(h.sessions, sessionID)
		return false
	default:
		return true
	}
}
// sweepExpiredSessionsLocked removes every session whose expiry is at or
// before nowSec. It is a no-op when the previous sweep happened less than
// sessionSweepInterval ago. The caller must hold h.mu for writing.
func (h *Handler) sweepExpiredSessionsLocked(nowSec int64) {
	sinceLastSweep := nowSec - h.lastSweep
	if sinceLastSweep < int64(sessionSweepInterval.Seconds()) {
		return
	}
	h.lastSweep = nowSec

	for id := range h.sessions {
		if h.sessions[id] <= nowSec {
			delete(h.sessions, id)
		}
	}
}
// createSession generates a new session ID, stores it, and sets the session cookie.
func (h *Handler) createSession(c *gin.Context) {
now := time.Now()
expiresAt := now.Add(sessionTTL).Unix()
sessionID := uuid.New().String()
h.mu.Lock()
h.sessions[sessionID] = true
h.sweepExpiredSessionsLocked(now.Unix())
h.sessions[sessionID] = expiresAt
h.mu.Unlock()
http.SetCookie(c.Writer, &http.Cookie{
Name: "session",
Value: sessionID,
Path: "/",
MaxAge: 3600 * 24 * 30,
MaxAge: int(sessionTTL.Seconds()),
HttpOnly: true,
Secure: isSecureRequest(c.Request),
SameSite: http.SameSiteLaxMode,
+38 -5
View File
@@ -1,6 +1,8 @@
package httpc
import (
"context"
"errors"
"net"
"net/http"
"net/url"
"strconv"
@@ -16,9 +18,9 @@ type clientPool struct {
var defaultClientPool = &clientPool{clients: make(map[string]*http.Client)}
// NewClient creates HTTP client with specified timeout and optional proxy.
// Clients are reused by (timeout, proxy) to keep connections warm.
func NewClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
key := proxyURL + "|" + strconv.FormatInt(timeout.Milliseconds(), 10)
// Clients are reused by (timeout, proxy, allowPrivate) to keep connections warm.
func NewClient(timeout time.Duration, proxyURL string, allowPrivate bool) (*http.Client, error) {
key := proxyURL + "|" + strconv.FormatInt(timeout.Milliseconds(), 10) + "|" + strconv.FormatBool(allowPrivate)
defaultClientPool.mu.RLock()
if client, ok := defaultClientPool.clients[key]; ok {
@@ -35,6 +37,13 @@ func NewClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: timeout,
}
dialer := &net.Dialer{Timeout: 10 * time.Second, KeepAlive: 30 * time.Second}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if err := validateDialTarget(ctx, addr, allowPrivate); err != nil {
return nil, err
}
return dialer.DialContext(ctx, network, addr)
}
if proxyURL != "" {
proxy, err := url.Parse(proxyURL)
@@ -45,8 +54,9 @@ func NewClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
}
client := &http.Client{
Timeout: timeout,
Transport: transport,
Timeout: timeout,
Transport: transport,
CheckRedirect: redirectValidator(allowPrivate),
}
defaultClientPool.mu.Lock()
@@ -61,6 +71,29 @@ func NewClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
return client, nil
}
// redirectValidator returns a CheckRedirect policy for http.Client that
// validates every redirect target with ValidateRequestURL, so a public URL
// cannot bounce the client to a private or otherwise disallowed address.
//
// Setting a custom CheckRedirect replaces net/http's default policy, which is
// what normally stops a client after 10 consecutive requests. The same limit
// is re-applied here so a redirect loop fails fast instead of running until
// the client timeout.
func redirectValidator(allowPrivate bool) func(req *http.Request, via []*http.Request) error {
	// Mirrors the limit used by net/http's default redirect policy.
	const maxRedirects = 10

	return func(req *http.Request, via []*http.Request) error {
		if len(via) >= maxRedirects {
			return errors.New("stopped after 10 redirects")
		}
		if req == nil || req.URL == nil {
			return nil
		}
		return ValidateRequestURL(req.Context(), req.URL.String(), allowPrivate)
	}
}
// validateDialTarget checks the address a connection is about to be dialed
// to. When allowPrivate is false the host part of addr must pass
// validatePublicHost; when it is true every address is accepted.
//
// NOTE(review): if the transport is configured with a proxy, addr is
// presumably the proxy's address rather than the feed host — confirm that a
// proxy on a private or loopback address is not unintentionally rejected.
func validateDialTarget(ctx context.Context, addr string, allowPrivate bool) error {
	if allowPrivate {
		return nil
	}

	hostname, _, err := net.SplitHostPort(addr)
	if err != nil {
		// addr could not be split into host and port; validate it as given.
		hostname = addr
	}
	return validatePublicHost(ctx, hostname)
}
// SetDefaultHeaders adds default headers required for feed fetching.
func SetDefaultHeaders(req *http.Request) {
req.Header.Set("User-Agent", "fusion/1.0")
+74
View File
@@ -0,0 +1,74 @@
package httpc
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"time"
)
// ValidateRequestURL checks that rawURL is safe to fetch.
//
// The URL is trimmed and must parse as a request URI with an http or https
// scheme and a non-empty host. Unless allowPrivate is true, the host must
// also pass validatePublicHost, which rejects localhost and private or local
// addresses (resolving host names through DNS when needed).
func ValidateRequestURL(ctx context.Context, rawURL string, allowPrivate bool) error {
	u, err := url.ParseRequestURI(strings.TrimSpace(rawURL))
	if err != nil {
		return fmt.Errorf("invalid url: %w", err)
	}

	switch u.Scheme {
	case "http", "https":
		// Supported.
	default:
		return fmt.Errorf("unsupported url scheme: %s", u.Scheme)
	}

	hostname := strings.TrimSpace(u.Hostname())
	if hostname == "" {
		return fmt.Errorf("url host is required")
	}

	if allowPrivate {
		return nil
	}
	return validatePublicHost(ctx, hostname)
}
// validatePublicHost returns an error when host is "localhost", a private or
// local IP literal, or a name that resolves to at least one such address.
//
// Names are resolved with the default resolver under a 2 second timeout
// derived from ctx (a nil ctx is treated as context.Background()). A failed
// or empty lookup is reported as an error rather than treated as public.
func validatePublicHost(ctx context.Context, host string) error {
	const privateHostMsg = "private host is not allowed"

	if strings.EqualFold(host, "localhost") {
		return fmt.Errorf(privateHostMsg)
	}

	// IP literals are judged directly, without touching DNS.
	if literal := net.ParseIP(host); literal != nil {
		if isPrivateOrLocalIP(literal) {
			return fmt.Errorf(privateHostMsg)
		}
		return nil
	}

	if ctx == nil {
		ctx = context.Background()
	}
	lookupCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	resolved, err := net.DefaultResolver.LookupIPAddr(lookupCtx, host)
	switch {
	case err != nil:
		return fmt.Errorf("resolve host: %w", err)
	case len(resolved) == 0:
		return fmt.Errorf("resolve host: no addresses")
	}

	// A single private address among the results is enough to reject the host.
	for i := range resolved {
		if isPrivateOrLocalIP(resolved[i].IP) {
			return fmt.Errorf(privateHostMsg)
		}
	}
	return nil
}
// isPrivateOrLocalIP reports whether ip is a loopback, private, link-local,
// multicast, or unspecified address.
func isPrivateOrLocalIP(ip net.IP) bool {
	switch {
	case ip.IsLoopback(), ip.IsPrivate():
		return true
	case ip.IsLinkLocalMulticast(), ip.IsLinkLocalUnicast():
		return true
	case ip.IsMulticast(), ip.IsInterfaceLocalMulticast():
		return true
	case ip.IsUnspecified():
		return true
	}
	return false
}
+26 -10
View File
@@ -2,9 +2,12 @@ package pull
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
@@ -24,8 +27,12 @@ type ParsedItem struct {
// FetchAndParse fetches RSS/Atom feed and parses into items.
// Returns parsed items and optional site URL discovered from feed metadata.
func FetchAndParse(ctx context.Context, feed *model.Feed, timeout time.Duration) ([]*ParsedItem, string, error) {
client, err := httpc.NewClient(timeout, feed.Proxy)
func FetchAndParse(ctx context.Context, feed *model.Feed, timeout time.Duration, allowPrivateFeeds bool) ([]*ParsedItem, string, error) {
if err := httpc.ValidateRequestURL(ctx, feed.Link, allowPrivateFeeds); err != nil {
return nil, "", fmt.Errorf("validate feed url: %w", err)
}
client, err := httpc.NewClient(timeout, feed.Proxy, allowPrivateFeeds)
if err != nil {
return nil, "", fmt.Errorf("create client: %w", err)
}
@@ -97,11 +104,6 @@ func normalizeSiteURL(raw string) string {
// - pub_date: prefer PublishedParsed, fallback to UpdatedParsed
// - link: convert to absolute URL
func mapItem(item *gofeed.Item, baseURL *url.URL) *ParsedItem {
guid := item.GUID
if guid == "" {
guid = item.Link
}
content := item.Content
if content == "" {
content = item.Description
@@ -116,13 +118,22 @@ func mapItem(item *gofeed.Item, baseURL *url.URL) *ParsedItem {
pubDate = time.Now().Unix()
}
link := item.Link
if baseURL != nil {
if absURL, err := baseURL.Parse(link); err == nil {
rawLink := strings.TrimSpace(item.Link)
link := rawLink
if rawLink != "" && baseURL != nil {
if absURL, err := baseURL.Parse(rawLink); err == nil {
link = absURL.String()
}
}
guid := strings.TrimSpace(item.GUID)
if guid == "" {
guid = strings.TrimSpace(link)
}
if guid == "" {
guid = fallbackGUID(item.Title, content, pubDate)
}
return &ParsedItem{
GUID: guid,
Title: item.Title,
@@ -131,3 +142,8 @@ func mapItem(item *gofeed.Item, baseURL *url.URL) *ParsedItem {
PubDate: pubDate,
}
}
// fallbackGUID derives a stable identifier for an item that has neither a
// GUID nor a link. The result is "generated:" followed by the hex SHA-256 of
// the trimmed title, the trimmed content, and the publication timestamp,
// joined by newlines.
func fallbackGUID(title, content string, pubDate int64) string {
	payload := strings.Join([]string{
		strings.TrimSpace(title),
		strings.TrimSpace(content),
		strconv.FormatInt(pubDate, 10),
	}, "\n")

	digest := sha256.Sum256([]byte(payload))
	return "generated:" + hex.EncodeToString(digest[:])
}
+29 -21
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"time"
"github.com/0x2E/fusion/internal/config"
@@ -64,28 +65,16 @@ func (p *Puller) pullAll(ctx context.Context) {
return
}
for _, feed := range feeds {
if ShouldSkip(feed, p.interval, p.maxBackoff) {
continue
}
// Acquire semaphore slot (blocks if at capacity)
if err := p.concurrency.Acquire(ctx, 1); err != nil {
return // Context cancelled
}
go func(f *model.Feed) {
defer p.concurrency.Release(1)
p.pullFeed(ctx, f)
}(feed)
}
_, _ = p.dispatchFeeds(ctx, feeds, func(feed *model.Feed) bool {
return !ShouldSkip(feed, p.interval, p.maxBackoff)
})
}
// pullFeed fetches single feed and saves new items.
func (p *Puller) pullFeed(ctx context.Context, feed *model.Feed) {
p.logger.Debug("pulling feed", "feed_id", feed.ID, "feed_name", feed.Name)
items, siteURL, err := FetchAndParse(ctx, feed, p.timeout)
items, siteURL, err := FetchAndParse(ctx, feed, p.timeout, p.config.AllowPrivateFeeds)
if err != nil {
p.logger.Warn("failed to fetch feed", "feed_id", feed.ID, "feed_name", feed.Name, "error", err)
if err := p.store.UpdateFeedFailure(feed.ID, err.Error()); err != nil {
@@ -125,7 +114,8 @@ func (p *Puller) pullFeed(ctx context.Context, feed *model.Feed) {
p.logger.Info("feed pulled successfully", "feed_id", feed.ID, "feed_name", feed.Name, "new_items", newCount)
}
// RefreshAll triggers refresh for all non-suspended feeds, bypassing backoff/interval skip logic.
// RefreshAll triggers refresh for all non-suspended feeds and waits until all
// started refresh jobs have completed. It bypasses backoff/interval skip logic.
// Concurrency is controlled by the same semaphore as periodic pulls.
func (p *Puller) RefreshAll(ctx context.Context) (int, error) {
feeds, err := p.store.ListFeeds()
@@ -133,24 +123,42 @@ func (p *Puller) RefreshAll(ctx context.Context) (int, error) {
return 0, fmt.Errorf("list feeds: %w", err)
}
count, err := p.dispatchFeeds(ctx, feeds, func(feed *model.Feed) bool {
return !feed.Suspended
})
if err != nil {
return count, err
}
return count, nil
}
func (p *Puller) dispatchFeeds(ctx context.Context, feeds []*model.Feed, shouldPull func(*model.Feed) bool) (int, error) {
count := 0
var wg sync.WaitGroup
var acquireErr error
for _, feed := range feeds {
if feed.Suspended {
if !shouldPull(feed) {
continue
}
count++
if err := p.concurrency.Acquire(ctx, 1); err != nil {
return count, err
acquireErr = err
break
}
count++
wg.Add(1)
go func(f *model.Feed) {
defer wg.Done()
defer p.concurrency.Release(1)
p.pullFeed(ctx, f)
}(feed)
}
return count, nil
wg.Wait()
return count, acquireErr
}
// RefreshFeed manually triggers refresh for specific feed (bypasses skip logic).
+15 -21
View File
@@ -267,39 +267,24 @@ func (s *Store) BatchCreateFeeds(inputs []BatchCreateFeedsInput) (*BatchCreateFe
}
defer tx.Rollback()
// Check existing links to avoid duplicates
existingLinks := make(map[string]bool)
rows, err := tx.Query(`SELECT link FROM feeds`)
if err != nil {
return nil, err
}
for rows.Next() {
var link string
if err := rows.Scan(&link); err != nil {
rows.Close()
return nil, err
}
existingLinks[link] = true
}
rows.Close()
if err := rows.Err(); err != nil {
return nil, err
}
stmt, err := tx.Prepare(`
INSERT INTO feeds (group_id, name, link, site_url, proxy)
VALUES (:group_id, :name, :link, :site_url, '')
ON CONFLICT(link) DO NOTHING
`)
if err != nil {
return nil, err
}
defer stmt.Close()
seenLinks := make(map[string]bool, len(inputs))
for _, input := range inputs {
if existingLinks[input.Link] {
if seenLinks[input.Link] {
result.Errors = append(result.Errors, fmt.Sprintf("duplicate feed: %s", input.Link))
continue
}
seenLinks[input.Link] = true
res, err := stmt.Exec(
sql.Named("group_id", input.GroupID),
@@ -312,13 +297,22 @@ func (s *Store) BatchCreateFeeds(inputs []BatchCreateFeedsInput) (*BatchCreateFe
continue
}
affected, err := res.RowsAffected()
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("failed to inspect result for %s: %v", input.Link, err))
continue
}
if affected == 0 {
result.Errors = append(result.Errors, fmt.Sprintf("duplicate feed: %s", input.Link))
continue
}
id, err := res.LastInsertId()
if err != nil {
result.Errors = append(result.Errors, fmt.Sprintf("failed to get id for %s: %v", input.Link, err))
continue
}
existingLinks[input.Link] = true
result.Created++
result.CreatedIDs = append(result.CreatedIDs, id)
}
+72 -3
View File
@@ -199,13 +199,33 @@ func (s *Store) UpdateItemUnread(id int64, unread bool) error {
return nil
}
// BatchUpdateItemsUnread marks multiple items as read/unread in a single query.
// Dynamically builds IN clause with named parameters (:id0, :id1, ...) for safety.
// BatchUpdateItemsUnread sets the unread flag on every item in ids.
// The IDs are processed in chunks of at most 500 so that each SQL statement
// stays bounded. Processing stops at the first chunk that fails and that
// error is returned; an empty ids slice is a no-op.
func (s *Store) BatchUpdateItemsUnread(ids []int64, unread bool) error {
	const chunkSize = 500

	for remaining := ids; len(remaining) > 0; {
		n := chunkSize
		if len(remaining) < n {
			n = len(remaining)
		}

		if err := s.batchUpdateItemsUnreadChunk(remaining[:n], unread); err != nil {
			return err
		}
		remaining = remaining[n:]
	}
	return nil
}
func (s *Store) batchUpdateItemsUnreadChunk(ids []int64, unread bool) error {
if len(ids) == 0 {
return nil
}
placeholders := make([]string, len(ids))
args := make([]interface{}, 0, len(ids)+1)
args = append(args, sql.Named("unread", boolToInt(unread)))
@@ -246,11 +266,60 @@ type SearchItemResult struct {
}
// SearchItems returns up to limit items matching query, newest first.
//
// The query is translated into an FTS match expression and run against the
// items_fts index. When no match expression can be built, or the FTS query
// itself fails, the search falls back to searchItemsLike. The returned slice
// is never nil on success.
func (s *Store) SearchItems(query string, limit int) ([]*SearchItemResult, error) {
	matchExpr := buildFTSQuery(query)
	if matchExpr == "" {
		return s.searchItemsLike(query, limit)
	}

	const ftsSearchSQL = `
SELECT i.id, i.feed_id, i.title, i.pub_date
FROM items_fts
INNER JOIN items i ON i.id = items_fts.rowid
WHERE items_fts MATCH :query
ORDER BY i.pub_date DESC, i.id DESC
LIMIT :limit
`
	rows, err := s.db.Query(ftsSearchSQL, sql.Named("query", matchExpr), sql.Named("limit", limit))
	if err != nil {
		// Best effort: a failing FTS query degrades to the LIKE-based search.
		return s.searchItemsLike(query, limit)
	}
	defer rows.Close()

	results := []*SearchItemResult{}
	for rows.Next() {
		var hit SearchItemResult
		if err := rows.Scan(&hit.ID, &hit.FeedID, &hit.Title, &hit.PubDate); err != nil {
			return nil, err
		}
		results = append(results, &hit)
	}
	return results, rows.Err()
}
// buildFTSQuery converts free-text user input into an FTS match expression.
// Each whitespace-separated word becomes a double-quoted prefix term (embedded
// double quotes are doubled) and the terms are combined with AND. Blank input
// yields an empty string.
func buildFTSQuery(query string) string {
	var expr strings.Builder
	// strings.Fields already discards empty tokens and surrounding whitespace.
	for i, word := range strings.Fields(query) {
		if i > 0 {
			expr.WriteString(" AND ")
		}
		expr.WriteByte('"')
		expr.WriteString(strings.ReplaceAll(word, `"`, `""`))
		expr.WriteString(`"*`)
	}
	return expr.String()
}
func (s *Store) searchItemsLike(query string, limit int) ([]*SearchItemResult, error) {
rows, err := s.db.Query(`
SELECT id, feed_id, title, pub_date
FROM items
WHERE title LIKE :query OR content LIKE :query
ORDER BY pub_date DESC
ORDER BY pub_date DESC, id DESC
LIMIT :limit
`, sql.Named("query", "%"+query+"%"), sql.Named("limit", limit))
if err != nil {
@@ -16,6 +16,7 @@ CREATE TABLE IF NOT EXISTS feeds (
link TEXT NOT NULL UNIQUE,
site_url TEXT DEFAULT '',
last_build INTEGER DEFAULT 0,
last_failure_at INTEGER NOT NULL DEFAULT 0,
failure TEXT DEFAULT '',
failures INTEGER DEFAULT 0,
suspended INTEGER DEFAULT 0,
@@ -42,6 +43,33 @@ CREATE TABLE IF NOT EXISTS items (
CREATE UNIQUE INDEX IF NOT EXISTS idx_items_feed_guid ON items(feed_id, guid);
CREATE INDEX IF NOT EXISTS idx_items_unread ON items(unread) WHERE unread = 1;
CREATE INDEX IF NOT EXISTS idx_items_pub_date ON items(pub_date DESC);
CREATE INDEX IF NOT EXISTS idx_items_feed_unread ON items(feed_id, unread);
-- Full-text index over item titles and content (FTS5, unicode61 tokenizer).
-- Rows are keyed by rowid, which mirrors items.id.
CREATE VIRTUAL TABLE IF NOT EXISTS items_fts USING fts5(
title,
content,
tokenize = 'unicode61'
);
-- Backfill: index every existing item that has no items_fts row yet.
INSERT INTO items_fts(rowid, title, content)
SELECT id, title, content
FROM items
WHERE id NOT IN (SELECT rowid FROM items_fts);
-- Keep items_fts in sync: index each newly inserted item.
CREATE TRIGGER IF NOT EXISTS items_fts_items_ai AFTER INSERT ON items BEGIN
INSERT INTO items_fts(rowid, title, content)
VALUES (new.id, new.title, new.content);
END;
-- Keep items_fts in sync: drop the index row of each deleted item.
CREATE TRIGGER IF NOT EXISTS items_fts_items_ad AFTER DELETE ON items BEGIN
DELETE FROM items_fts WHERE rowid = old.id;
END;
-- Keep items_fts in sync: re-index an item when its searchable text changes.
-- The trigger is limited to updates that assign title or content, so updates
-- touching only other columns (for example the unread flag) do not rewrite
-- the full-text entry.
-- NOTE(review): because of IF NOT EXISTS, a database that already holds an
-- older definition of this trigger keeps it until the trigger is dropped.
CREATE TRIGGER IF NOT EXISTS items_fts_items_au AFTER UPDATE OF title, content ON items BEGIN
DELETE FROM items_fts WHERE rowid = old.id;
INSERT INTO items_fts(rowid, title, content)
VALUES (new.id, new.title, new.content);
END;
CREATE TABLE IF NOT EXISTS bookmarks (
@@ -56,4 +84,3 @@ CREATE TABLE IF NOT EXISTS bookmarks (
);
CREATE INDEX IF NOT EXISTS idx_bookmarks_created_at ON bookmarks(created_at DESC);
@@ -1 +0,0 @@
ALTER TABLE feeds ADD COLUMN last_failure_at INTEGER NOT NULL DEFAULT 0;
@@ -1 +0,0 @@
CREATE INDEX IF NOT EXISTS idx_items_feed_unread ON items(feed_id, unread);