mirror of
https://github.com/0x2E/fusion.git
synced 2026-05-19 18:30:35 +00:00
harden auth config and optimize feed pulling
This commit is contained in:
+12
-2
@@ -2,7 +2,9 @@
|
||||
FUSION_DB_PATH=fusion.db
|
||||
|
||||
# Authentication
|
||||
FUSION_PASSWORD=admin
|
||||
FUSION_PASSWORD=changeme
|
||||
# Explicitly allow empty password (not recommended)
|
||||
# FUSION_ALLOW_EMPTY_PASSWORD=false
|
||||
|
||||
# Server Configuration
|
||||
FUSION_PORT=8080
|
||||
@@ -20,6 +22,14 @@ FUSION_PULL_CONCURRENCY=10
|
||||
# Maximum backoff time in seconds (default: 604800 = 7 days)
|
||||
FUSION_PULL_MAX_BACKOFF=604800
|
||||
|
||||
# Login rate limiting
|
||||
# Max failed attempts per window (default: 10)
|
||||
FUSION_LOGIN_RATE_LIMIT=10
|
||||
# Window size in seconds (default: 60)
|
||||
FUSION_LOGIN_WINDOW=60
|
||||
# Block duration in seconds after limit is exceeded (default: 300)
|
||||
FUSION_LOGIN_BLOCK=300
|
||||
|
||||
# Logging Configuration
|
||||
# Log level: DEBUG, INFO, WARN, ERROR (default: INFO)
|
||||
FUSION_LOG_LEVEL=INFO
|
||||
@@ -35,7 +45,7 @@ FUSION_LOG_FORMAT=auto
|
||||
# FUSION_OIDC_CLIENT_ID=
|
||||
# FUSION_OIDC_CLIENT_SECRET=
|
||||
|
||||
# Callback URL (default: auto-detected from Host header as {scheme}://{host}/api/oidc/callback)
|
||||
# Callback URL (required when OIDC issuer is configured)
|
||||
# FUSION_OIDC_REDIRECT_URI=
|
||||
|
||||
# Restrict login to a specific user identity (email or subject claim, optional)
|
||||
|
||||
+11
-1
@@ -2,7 +2,9 @@
|
||||
FUSION_DB_PATH=fusion.db
|
||||
|
||||
# Authentication
|
||||
FUSION_PASSWORD=admin
|
||||
FUSION_PASSWORD=changeme
|
||||
# Explicitly allow empty password (not recommended)
|
||||
# FUSION_ALLOW_EMPTY_PASSWORD=false
|
||||
|
||||
# Server Configuration
|
||||
FUSION_PORT=8080
|
||||
@@ -20,6 +22,14 @@ FUSION_PULL_CONCURRENCY=10
|
||||
# Maximum backoff time in seconds (default: 604800 = 7 days)
|
||||
FUSION_PULL_MAX_BACKOFF=604800
|
||||
|
||||
# Login rate limiting
|
||||
# Max failed attempts per window (default: 10)
|
||||
FUSION_LOGIN_RATE_LIMIT=10
|
||||
# Window size in seconds (default: 60)
|
||||
FUSION_LOGIN_WINDOW=60
|
||||
# Block duration in seconds after limit is exceeded (default: 300)
|
||||
FUSION_LOGIN_BLOCK=300
|
||||
|
||||
# Logging Configuration
|
||||
# Log level: DEBUG, INFO, WARN, ERROR (default: INFO)
|
||||
FUSION_LOG_LEVEL=INFO
|
||||
|
||||
@@ -27,7 +27,10 @@ func main() {
|
||||
}
|
||||
|
||||
func run() error {
|
||||
cfg := config.Load()
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setupLogger(cfg)
|
||||
|
||||
st, err := store.New(cfg.DBPath)
|
||||
|
||||
@@ -84,16 +84,6 @@ func (a *OIDCAuthenticator) AuthURL() (authURL string, err error) {
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// SetRedirectURI sets the redirect URI dynamically (for auto-detection from Host header).
|
||||
func (a *OIDCAuthenticator) SetRedirectURI(uri string) {
|
||||
a.oauth2Config.RedirectURL = uri
|
||||
}
|
||||
|
||||
// RedirectURI returns the currently configured redirect URI.
|
||||
func (a *OIDCAuthenticator) RedirectURI() string {
|
||||
return a.oauth2Config.RedirectURL
|
||||
}
|
||||
|
||||
// Callback exchanges the authorization code for tokens and verifies the ID token.
|
||||
// Returns a user identifier (email or subject claim).
|
||||
func (a *OIDCAuthenticator) Callback(ctx context.Context, state, code string) (string, error) {
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DBPath string
|
||||
Password string // Plaintext password from env
|
||||
Host string // TODO parse and use
|
||||
Port int
|
||||
|
||||
PullInterval int // Pull interval in seconds (default: 1800 = 30 min)
|
||||
@@ -16,6 +17,10 @@ type Config struct {
|
||||
PullConcurrency int // Max concurrent pulls (default: 10)
|
||||
PullMaxBackoff int // Max backoff time in seconds (default: 604800 = 7 days)
|
||||
|
||||
LoginRateLimit int // Max failed login attempts per window (default: 10)
|
||||
LoginWindow int // Login rate limit window in seconds (default: 60)
|
||||
LoginBlock int // Login block duration in seconds (default: 300)
|
||||
|
||||
LogLevel string // Log level: DEBUG, INFO, WARN, ERROR (default: INFO)
|
||||
LogFormat string // Log format: text, json, auto (default: auto)
|
||||
|
||||
@@ -23,11 +28,11 @@ type Config struct {
|
||||
OIDCIssuer string // OIDC provider URL
|
||||
OIDCClientID string // OAuth2 client ID
|
||||
OIDCClientSecret string // OAuth2 client secret
|
||||
OIDCRedirectURI string // Callback URL (default: auto-detect from Host header)
|
||||
OIDCRedirectURI string // Callback URL (required when OIDC is enabled)
|
||||
OIDCAllowedUser string // Optional: restrict to specific user identity (email or sub)
|
||||
}
|
||||
|
||||
func Load() *Config {
|
||||
func Load() (*Config, error) {
|
||||
// Backward compatible env vars:
|
||||
// - DB (legacy) -> FUSION_DB_PATH
|
||||
// - PASSWORD (legacy) -> FUSION_PASSWORD
|
||||
@@ -44,8 +49,14 @@ func Load() *Config {
|
||||
if password == "" {
|
||||
password = os.Getenv("PASSWORD")
|
||||
}
|
||||
if password == "" {
|
||||
password = "admin" // TODO allow empty password
|
||||
|
||||
allowEmptyPassword, err := getEnvBool("FUSION_ALLOW_EMPTY_PASSWORD", false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.TrimSpace(password) == "" && !allowEmptyPassword {
|
||||
return nil, fmt.Errorf("FUSION_PASSWORD is required (or set FUSION_ALLOW_EMPTY_PASSWORD=true)")
|
||||
}
|
||||
|
||||
port := os.Getenv("FUSION_PORT")
|
||||
@@ -57,7 +68,40 @@ func Load() *Config {
|
||||
}
|
||||
parsedPort, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, fmt.Errorf("invalid FUSION_PORT: %w", err)
|
||||
}
|
||||
if parsedPort <= 0 || parsedPort > 65535 {
|
||||
return nil, fmt.Errorf("invalid FUSION_PORT: must be in range 1-65535")
|
||||
}
|
||||
|
||||
pullInterval, err := getEnvInt("FUSION_PULL_INTERVAL", 1800, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pullTimeout, err := getEnvInt("FUSION_PULL_TIMEOUT", 30, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pullConcurrency, err := getEnvInt("FUSION_PULL_CONCURRENCY", 10, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pullMaxBackoff, err := getEnvInt("FUSION_PULL_MAX_BACKOFF", 604800, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRateLimit, err := getEnvInt("FUSION_LOGIN_RATE_LIMIT", 10, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loginWindow, err := getEnvInt("FUSION_LOGIN_WINDOW", 60, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loginBlock, err := getEnvInt("FUSION_LOGIN_BLOCK", 300, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logLevel := os.Getenv("FUSION_LOG_LEVEL")
|
||||
@@ -74,10 +118,13 @@ func Load() *Config {
|
||||
DBPath: dbPath,
|
||||
Password: password,
|
||||
Port: parsedPort,
|
||||
PullInterval: getEnvInt("FUSION_PULL_INTERVAL", 1800),
|
||||
PullTimeout: getEnvInt("FUSION_PULL_TIMEOUT", 30),
|
||||
PullConcurrency: getEnvInt("FUSION_PULL_CONCURRENCY", 10),
|
||||
PullMaxBackoff: getEnvInt("FUSION_PULL_MAX_BACKOFF", 604800),
|
||||
PullInterval: pullInterval,
|
||||
PullTimeout: pullTimeout,
|
||||
PullConcurrency: pullConcurrency,
|
||||
PullMaxBackoff: pullMaxBackoff,
|
||||
LoginRateLimit: loginRateLimit,
|
||||
LoginWindow: loginWindow,
|
||||
LoginBlock: loginBlock,
|
||||
LogLevel: logLevel,
|
||||
LogFormat: logFormat,
|
||||
|
||||
@@ -86,17 +133,32 @@ func Load() *Config {
|
||||
OIDCClientSecret: os.Getenv("FUSION_OIDC_CLIENT_SECRET"),
|
||||
OIDCRedirectURI: os.Getenv("FUSION_OIDC_REDIRECT_URI"),
|
||||
OIDCAllowedUser: os.Getenv("FUSION_OIDC_ALLOWED_USER"),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultVal int) int {
|
||||
func getEnvInt(key string, defaultVal, minVal int) (int, error) {
|
||||
val := os.Getenv(key)
|
||||
if val == "" {
|
||||
return defaultVal
|
||||
return defaultVal, nil
|
||||
}
|
||||
parsed, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
return 0, fmt.Errorf("invalid %s: %w", key, err)
|
||||
}
|
||||
return parsed
|
||||
if parsed < minVal {
|
||||
return 0, fmt.Errorf("invalid %s: must be >= %d", key, minVal)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultVal bool) (bool, error) {
|
||||
val := os.Getenv(key)
|
||||
if val == "" {
|
||||
return defaultVal, nil
|
||||
}
|
||||
parsed, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid %s: %w", key, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
@@ -23,16 +23,19 @@ func (h *Handler) listBookmarks(c *gin.Context) {
|
||||
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
val, err := strconv.Atoi(limitStr)
|
||||
if err != nil {
|
||||
if err != nil || val <= 0 {
|
||||
badRequestError(c, "invalid limit")
|
||||
return
|
||||
}
|
||||
if val > maxListLimit {
|
||||
val = maxListLimit
|
||||
}
|
||||
limit = val
|
||||
}
|
||||
|
||||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||||
val, err := strconv.Atoi(offsetStr)
|
||||
if err != nil {
|
||||
if err != nil || val < 0 {
|
||||
badRequestError(c, "invalid offset")
|
||||
return
|
||||
}
|
||||
@@ -45,7 +48,13 @@ func (h *Handler) listBookmarks(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
listResponse(c, bookmarks, len(bookmarks))
|
||||
total, err := h.store.CountBookmarks()
|
||||
if err != nil {
|
||||
internalError(c, err, "count bookmarks")
|
||||
return
|
||||
}
|
||||
|
||||
listResponse(c, bookmarks, total)
|
||||
}
|
||||
|
||||
func (h *Handler) getBookmark(c *gin.Context) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -31,3 +32,11 @@ func badRequestError(c *gin.Context, message string) {
|
||||
func unauthorizedError(c *gin.Context) {
|
||||
c.JSON(401, gin.H{"error": "unauthorized"})
|
||||
}
|
||||
|
||||
// tooManyRequestsError returns 429 and sets Retry-After when available.
|
||||
func tooManyRequestsError(c *gin.Context, retryAfterSec int64) {
|
||||
if retryAfterSec > 0 {
|
||||
c.Header("Retry-After", strconv.FormatInt(retryAfterSec, 10))
|
||||
}
|
||||
c.JSON(429, gin.H{"error": "too many login attempts"})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/0x2E/fusion/internal/auth"
|
||||
@@ -23,6 +24,7 @@ type Handler struct {
|
||||
sessions map[string]bool // sessionID -> valid, in-memory session store
|
||||
mu sync.RWMutex // protects sessions map
|
||||
oidcAuth *auth.OIDCAuthenticator // nil when OIDC is disabled
|
||||
limiter *loginLimiter
|
||||
}
|
||||
|
||||
func New(store *store.Store, config *config.Config, puller interface {
|
||||
@@ -41,9 +43,14 @@ func New(store *store.Store, config *config.Config, puller interface {
|
||||
passwordHash: passwordHash,
|
||||
puller: puller,
|
||||
sessions: make(map[string]bool),
|
||||
limiter: newLoginLimiter(config.LoginRateLimit, config.LoginWindow, config.LoginBlock),
|
||||
}
|
||||
|
||||
if config.OIDCIssuer != "" {
|
||||
if strings.TrimSpace(config.OIDCRedirectURI) == "" {
|
||||
return nil, fmt.Errorf("FUSION_OIDC_REDIRECT_URI is required when OIDC is enabled")
|
||||
}
|
||||
|
||||
oidcAuth, err := auth.NewOIDC(
|
||||
context.Background(),
|
||||
config.OIDCIssuer,
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const maxListLimit = 100
|
||||
|
||||
type markItemsReadRequest struct {
|
||||
IDs []int64 `json:"ids" binding:"required"`
|
||||
}
|
||||
@@ -44,10 +46,13 @@ func (h *Handler) listItems(c *gin.Context) {
|
||||
|
||||
if limit := c.Query("limit"); limit != "" {
|
||||
val, err := strconv.Atoi(limit)
|
||||
if err != nil {
|
||||
if err != nil || val <= 0 {
|
||||
badRequestError(c, "invalid limit")
|
||||
return
|
||||
}
|
||||
if val > maxListLimit {
|
||||
val = maxListLimit
|
||||
}
|
||||
params.Limit = val
|
||||
} else {
|
||||
params.Limit = 10
|
||||
@@ -55,7 +60,7 @@ func (h *Handler) listItems(c *gin.Context) {
|
||||
|
||||
if offset := c.Query("offset"); offset != "" {
|
||||
val, err := strconv.Atoi(offset)
|
||||
if err != nil {
|
||||
if err != nil || val < 0 {
|
||||
badRequestError(c, "invalid offset")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
@@ -18,15 +17,6 @@ func (h *Handler) oidcLogin(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-detect redirect URI from Host header if not explicitly configured
|
||||
if h.oidcAuth.RedirectURI() == "" {
|
||||
scheme := "https"
|
||||
if !isSecureRequest(c.Request) {
|
||||
scheme = "http"
|
||||
}
|
||||
h.oidcAuth.SetRedirectURI(fmt.Sprintf("%s://%s/api/oidc/callback", scheme, c.Request.Host))
|
||||
}
|
||||
|
||||
authURL, err := h.oidcAuth.AuthURL()
|
||||
if err != nil {
|
||||
internalError(c, err, "oidc auth url")
|
||||
|
||||
@@ -20,18 +20,21 @@ func (h *Handler) search(c *gin.Context) {
|
||||
badRequestError(c, "invalid limit")
|
||||
return
|
||||
}
|
||||
if parsed > maxListLimit {
|
||||
parsed = maxListLimit
|
||||
}
|
||||
limit = parsed
|
||||
}
|
||||
|
||||
feeds, err := h.store.SearchFeeds(q)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": "search feeds: " + err.Error()})
|
||||
internalError(c, err, "search feeds")
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.store.SearchItems(q, limit)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": "search items: " + err.Error()})
|
||||
internalError(c, err, "search items")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,102 @@ package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/0x2E/fusion/internal/auth"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type loginState struct {
|
||||
windowStart int64
|
||||
failures int
|
||||
blockedTill int64
|
||||
}
|
||||
|
||||
type loginLimiter struct {
|
||||
mu sync.Mutex
|
||||
states map[string]loginState
|
||||
limit int
|
||||
windowSecs int64
|
||||
blockSecs int64
|
||||
lastSweepSec int64
|
||||
}
|
||||
|
||||
func newLoginLimiter(limit, windowSecs, blockSecs int) *loginLimiter {
|
||||
return &loginLimiter{
|
||||
states: make(map[string]loginState),
|
||||
limit: limit,
|
||||
windowSecs: int64(windowSecs),
|
||||
blockSecs: int64(blockSecs),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *loginLimiter) allow(ip string, now time.Time) (bool, int64) {
|
||||
nowSec := now.Unix()
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.sweep(nowSec)
|
||||
|
||||
state, ok := l.states[ip]
|
||||
if !ok {
|
||||
return true, 0
|
||||
}
|
||||
if state.blockedTill > nowSec {
|
||||
return false, state.blockedTill - nowSec
|
||||
}
|
||||
|
||||
return true, 0
|
||||
}
|
||||
|
||||
func (l *loginLimiter) recordFailure(ip string, now time.Time) {
|
||||
nowSec := now.Unix()
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.sweep(nowSec)
|
||||
|
||||
state := l.states[ip]
|
||||
if state.windowStart == 0 || nowSec-state.windowStart >= l.windowSecs {
|
||||
state.windowStart = nowSec
|
||||
state.failures = 0
|
||||
}
|
||||
|
||||
state.failures++
|
||||
if state.failures >= l.limit {
|
||||
state.blockedTill = nowSec + l.blockSecs
|
||||
state.windowStart = nowSec
|
||||
state.failures = 0
|
||||
}
|
||||
|
||||
l.states[ip] = state
|
||||
}
|
||||
|
||||
func (l *loginLimiter) recordSuccess(ip string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
delete(l.states, ip)
|
||||
}
|
||||
|
||||
func (l *loginLimiter) sweep(nowSec int64) {
|
||||
if nowSec-l.lastSweepSec < 60 {
|
||||
return
|
||||
}
|
||||
l.lastSweepSec = nowSec
|
||||
|
||||
for ip, state := range l.states {
|
||||
windowExpired := state.windowStart > 0 && nowSec-state.windowStart >= l.windowSecs
|
||||
unblocked := state.blockedTill > 0 && state.blockedTill <= nowSec
|
||||
if (state.blockedTill == 0 && windowExpired) || unblocked {
|
||||
delete(l.states, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isSecureRequest(r *http.Request) bool {
|
||||
if r.TLS != nil {
|
||||
return true
|
||||
@@ -16,20 +106,33 @@ func isSecureRequest(r *http.Request) bool {
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
Password *string `json:"password"`
|
||||
}
|
||||
|
||||
func (h *Handler) login(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
allowed, retryAfter := h.limiter.allow(ip, time.Now())
|
||||
if !allowed {
|
||||
tooManyRequestsError(c, retryAfter)
|
||||
return
|
||||
}
|
||||
|
||||
var req loginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
badRequestError(c, "invalid request")
|
||||
return
|
||||
}
|
||||
if req.Password == nil {
|
||||
badRequestError(c, "invalid request")
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.CheckPassword(h.passwordHash, req.Password); err != nil {
|
||||
if err := auth.CheckPassword(h.passwordHash, *req.Password); err != nil {
|
||||
h.limiter.recordFailure(ip, time.Now())
|
||||
unauthorizedError(c)
|
||||
return
|
||||
}
|
||||
h.limiter.recordSuccess(ip)
|
||||
|
||||
h.createSession(c)
|
||||
dataResponse(c, gin.H{"message": "logged in"})
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package handler
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLoginLimiterSweepDeletesExpiredWindowWithoutBlock(t *testing.T) {
|
||||
limiter := newLoginLimiter(3, 10, 30)
|
||||
limiter.states["1.1.1.1"] = loginState{windowStart: 10, failures: 1}
|
||||
|
||||
limiter.sweep(120)
|
||||
|
||||
if _, ok := limiter.states["1.1.1.1"]; ok {
|
||||
t.Fatal("expected expired non-blocked state to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginLimiterSweepDeletesUnblockedState(t *testing.T) {
|
||||
limiter := newLoginLimiter(3, 60, 30)
|
||||
limiter.states["2.2.2.2"] = loginState{windowStart: 100, blockedTill: 110}
|
||||
|
||||
limiter.sweep(120)
|
||||
|
||||
if _, ok := limiter.states["2.2.2.2"]; ok {
|
||||
t.Fatal("expected unblocked state to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginLimiterSweepKeepsActiveBlockedState(t *testing.T) {
|
||||
limiter := newLoginLimiter(3, 60, 30)
|
||||
limiter.states["3.3.3.3"] = loginState{windowStart: 100, blockedTill: 170}
|
||||
|
||||
limiter.sweep(120)
|
||||
|
||||
if _, ok := limiter.states["3.3.3.3"]; !ok {
|
||||
t.Fatal("expected active blocked state to remain")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/0x2E/fusion/internal/auth"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func newTestSessionHandler(t *testing.T, password string) *Handler {
|
||||
t.Helper()
|
||||
|
||||
hash, err := auth.HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("hash password: %v", err)
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
passwordHash: hash,
|
||||
sessions: make(map[string]bool),
|
||||
limiter: newLoginLimiter(10, 60, 300),
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginRejectsMissingPasswordField(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := newTestSessionHandler(t, "secret")
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/api/sessions", h.login)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions", strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginAcceptsEmptyPassword(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := newTestSessionHandler(t, "")
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/api/sessions", h.login)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/sessions", strings.NewReader(`{"password":""}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
if cookie := w.Header().Get("Set-Cookie"); !strings.Contains(cookie, "session=") {
|
||||
t.Fatalf("expected session cookie to be set, got %q", cookie)
|
||||
}
|
||||
}
|
||||
@@ -16,12 +16,14 @@ type Feed struct {
|
||||
Link string `json:"link"`
|
||||
SiteURL string `json:"site_url,omitempty"`
|
||||
LastBuild int64 `json:"last_build"`
|
||||
Failure string `json:"failure,omitempty"`
|
||||
Failures int64 `json:"failures"`
|
||||
Suspended bool `json:"suspended"`
|
||||
Proxy string `json:"proxy,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
// LastFailureAt is Unix timestamp of the most recent pull failure.
|
||||
LastFailureAt int64 `json:"last_failure_at"`
|
||||
Failure string `json:"failure,omitempty"`
|
||||
Failures int64 `json:"failures"`
|
||||
Suspended bool `json:"suspended"`
|
||||
Proxy string `json:"proxy,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
|
||||
UnreadCount int64 `json:"unread_count"`
|
||||
ItemCount int64 `json:"item_count"`
|
||||
|
||||
@@ -3,15 +3,37 @@ package httpc
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type clientPool struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*http.Client
|
||||
}
|
||||
|
||||
var defaultClientPool = &clientPool{clients: make(map[string]*http.Client)}
|
||||
|
||||
// NewClient creates HTTP client with specified timeout and optional proxy.
|
||||
// Returns client configured for HTTP/2 with keep-alives disabled.
|
||||
// Clients are reused by (timeout, proxy) to keep connections warm.
|
||||
func NewClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
|
||||
key := proxyURL + "|" + strconv.FormatInt(timeout.Milliseconds(), 10)
|
||||
|
||||
defaultClientPool.mu.RLock()
|
||||
if client, ok := defaultClientPool.clients[key]; ok {
|
||||
defaultClientPool.mu.RUnlock()
|
||||
return client, nil
|
||||
}
|
||||
defaultClientPool.mu.RUnlock()
|
||||
|
||||
transport := &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
ForceAttemptHTTP2: true,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 128,
|
||||
MaxIdleConnsPerHost: 16,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: timeout,
|
||||
}
|
||||
|
||||
if proxyURL != "" {
|
||||
@@ -22,14 +44,24 @@ func NewClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
|
||||
transport.Proxy = http.ProxyURL(proxy)
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: transport,
|
||||
}, nil
|
||||
}
|
||||
|
||||
defaultClientPool.mu.Lock()
|
||||
if existing, ok := defaultClientPool.clients[key]; ok {
|
||||
defaultClientPool.mu.Unlock()
|
||||
transport.CloseIdleConnections()
|
||||
return existing, nil
|
||||
}
|
||||
defaultClientPool.clients[key] = client
|
||||
defaultClientPool.mu.Unlock()
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// SetDefaultHeaders adds default headers required for feed fetching.
|
||||
func SetDefaultHeaders(req *http.Request) {
|
||||
req.Header.Set("User-Agent", "fusion/1.0")
|
||||
req.Header.Set("Connection", "close")
|
||||
}
|
||||
|
||||
@@ -38,7 +38,11 @@ func ShouldSkip(feed *model.Feed, interval, maxBackoff time.Duration) bool {
|
||||
// Skip if in backoff period
|
||||
if feed.Failures > 0 {
|
||||
backoff := CalculateBackoff(interval, feed.Failures, maxBackoff)
|
||||
nextPull := feed.LastBuild + int64(backoff.Seconds())
|
||||
base := feed.LastFailureAt
|
||||
if base <= 0 {
|
||||
base = feed.LastBuild
|
||||
}
|
||||
nextPull := base + int64(backoff.Seconds())
|
||||
if now < nextPull {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -94,23 +94,21 @@ func (p *Puller) pullFeed(ctx context.Context, feed *model.Feed) {
|
||||
return
|
||||
}
|
||||
|
||||
newCount := 0
|
||||
inputs := make([]store.BatchCreateItemInput, 0, len(items))
|
||||
for _, item := range items {
|
||||
exists, err := p.store.ItemExists(feed.ID, item.GUID)
|
||||
if err != nil {
|
||||
p.logger.Error("failed to check item existence", "feed_id", feed.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
inputs = append(inputs, store.BatchCreateItemInput{
|
||||
GUID: item.GUID,
|
||||
Title: item.Title,
|
||||
Link: item.Link,
|
||||
Content: item.Content,
|
||||
PubDate: item.PubDate,
|
||||
})
|
||||
}
|
||||
|
||||
_, err = p.store.CreateItem(feed.ID, item.GUID, item.Title, item.Link, item.Content, item.PubDate)
|
||||
if err != nil {
|
||||
p.logger.Error("failed to create item", "feed_id", feed.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
newCount++
|
||||
newCount, err := p.store.BatchCreateItemsIgnore(feed.ID, inputs)
|
||||
if err != nil {
|
||||
p.logger.Error("failed to batch create items", "feed_id", feed.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.store.UpdateFeedLastBuild(feed.ID, time.Now().Unix()); err != nil {
|
||||
|
||||
@@ -98,3 +98,9 @@ func (s *Store) BookmarkExists(link string) (bool, error) {
|
||||
err := s.db.QueryRow(`SELECT EXISTS(SELECT 1 FROM bookmarks WHERE link = :link)`, sql.Named("link", link)).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (s *Store) CountBookmarks() (int, error) {
|
||||
var count int
|
||||
err := s.db.QueryRow(`SELECT COUNT(*) FROM bookmarks`).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -11,11 +11,14 @@ import (
|
||||
|
||||
func (s *Store) ListFeeds() ([]*model.Feed, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT f.id, f.group_id, f.name, f.link, f.site_url, f.last_build,
|
||||
SELECT f.id, f.group_id, f.name, f.link, f.site_url, f.last_build, f.last_failure_at,
|
||||
f.failure, f.failures, f.suspended, f.proxy, f.created_at, f.updated_at,
|
||||
(SELECT COUNT(*) FROM items WHERE feed_id = f.id AND unread = 1) AS unread_count,
|
||||
(SELECT COUNT(*) FROM items WHERE feed_id = f.id) AS item_count
|
||||
COALESCE(SUM(CASE WHEN i.unread = 1 THEN 1 ELSE 0 END), 0) AS unread_count,
|
||||
COALESCE(COUNT(i.id), 0) AS item_count
|
||||
FROM feeds f
|
||||
LEFT JOIN items i ON i.feed_id = f.id
|
||||
GROUP BY f.id, f.group_id, f.name, f.link, f.site_url, f.last_build, f.last_failure_at,
|
||||
f.failure, f.failures, f.suspended, f.proxy, f.created_at, f.updated_at
|
||||
ORDER BY f.id
|
||||
`)
|
||||
if err != nil {
|
||||
@@ -27,7 +30,7 @@ func (s *Store) ListFeeds() ([]*model.Feed, error) {
|
||||
for rows.Next() {
|
||||
f := &model.Feed{}
|
||||
var suspended int
|
||||
if err := rows.Scan(&f.ID, &f.GroupID, &f.Name, &f.Link, &f.SiteURL, &f.LastBuild, &f.Failure, &f.Failures, &suspended, &f.Proxy, &f.CreatedAt, &f.UpdatedAt, &f.UnreadCount, &f.ItemCount); err != nil {
|
||||
if err := rows.Scan(&f.ID, &f.GroupID, &f.Name, &f.Link, &f.SiteURL, &f.LastBuild, &f.LastFailureAt, &f.Failure, &f.Failures, &suspended, &f.Proxy, &f.CreatedAt, &f.UpdatedAt, &f.UnreadCount, &f.ItemCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.Suspended = intToBool(suspended)
|
||||
@@ -40,10 +43,10 @@ func (s *Store) GetFeed(id int64) (*model.Feed, error) {
|
||||
f := &model.Feed{}
|
||||
var suspended int
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, group_id, name, link, site_url, last_build, failure, failures, suspended, proxy, created_at, updated_at
|
||||
SELECT id, group_id, name, link, site_url, last_build, last_failure_at, failure, failures, suspended, proxy, created_at, updated_at
|
||||
FROM feeds
|
||||
WHERE id = :id
|
||||
`, sql.Named("id", id)).Scan(&f.ID, &f.GroupID, &f.Name, &f.Link, &f.SiteURL, &f.LastBuild, &f.Failure, &f.Failures, &suspended, &f.Proxy, &f.CreatedAt, &f.UpdatedAt)
|
||||
`, sql.Named("id", id)).Scan(&f.ID, &f.GroupID, &f.Name, &f.Link, &f.SiteURL, &f.LastBuild, &f.LastFailureAt, &f.Failure, &f.Failures, &suspended, &f.Proxy, &f.CreatedAt, &f.UpdatedAt)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, fmt.Errorf("%w: feed", ErrNotFound)
|
||||
@@ -205,7 +208,7 @@ func (s *Store) DeleteFeed(id int64) error {
|
||||
func (s *Store) UpdateFeedLastBuild(id int64, lastBuild int64) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE feeds
|
||||
SET last_build = :last_build, failures = 0, failure = '', updated_at = unixepoch()
|
||||
SET last_build = :last_build, last_failure_at = 0, failures = 0, failure = '', updated_at = unixepoch()
|
||||
WHERE id = :id
|
||||
`, sql.Named("last_build", lastBuild), sql.Named("id", id))
|
||||
return err
|
||||
@@ -214,7 +217,7 @@ func (s *Store) UpdateFeedLastBuild(id int64, lastBuild int64) error {
|
||||
func (s *Store) UpdateFeedFailure(id int64, failure string) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE feeds
|
||||
SET failures = failures + 1, failure = :failure, updated_at = unixepoch()
|
||||
SET failures = failures + 1, failure = :failure, last_failure_at = unixepoch(), updated_at = unixepoch()
|
||||
WHERE id = :id
|
||||
`, sql.Named("failure", failure), sql.Named("id", id))
|
||||
return err
|
||||
|
||||
@@ -294,6 +294,9 @@ func TestUpdateFeedLastBuild(t *testing.T) {
|
||||
if updated.Failures != 0 {
|
||||
t.Error("expected failures to be cleared")
|
||||
}
|
||||
if updated.LastFailureAt != 0 {
|
||||
t.Error("expected last_failure_at to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFeedFailure(t *testing.T) {
|
||||
@@ -327,6 +330,10 @@ func TestUpdateFeedFailure(t *testing.T) {
|
||||
if updated1.Failures != 1 {
|
||||
t.Errorf("expected failures count to be 1, got %d", updated1.Failures)
|
||||
}
|
||||
if updated1.LastFailureAt == 0 {
|
||||
t.Error("expected last_failure_at to be set after failure")
|
||||
}
|
||||
firstFailureAt := updated1.LastFailureAt
|
||||
|
||||
errorMsg2 := "second error"
|
||||
if err := store.UpdateFeedFailure(feed.ID, errorMsg2); err != nil {
|
||||
@@ -345,6 +352,9 @@ func TestUpdateFeedFailure(t *testing.T) {
|
||||
if updated2.Failures != 2 {
|
||||
t.Errorf("expected failures count to be 2, got %d", updated2.Failures)
|
||||
}
|
||||
if updated2.LastFailureAt < firstFailureAt {
|
||||
t.Error("expected last_failure_at to be monotonic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFeedSiteURLIfEmpty(t *testing.T) {
|
||||
|
||||
@@ -122,6 +122,67 @@ func (s *Store) CreateItem(feedID int64, guid, title, link, content string, pubD
|
||||
return s.GetItem(id)
|
||||
}
|
||||
|
||||
type BatchCreateItemInput struct {
|
||||
GUID string
|
||||
Title string
|
||||
Link string
|
||||
Content string
|
||||
PubDate int64
|
||||
}
|
||||
|
||||
// BatchCreateItemsIgnore inserts items in one transaction and ignores duplicates by (feed_id, guid).
|
||||
// Returns the number of newly inserted rows.
|
||||
func (s *Store) BatchCreateItemsIgnore(feedID int64, inputs []BatchCreateItemInput) (int, error) {
|
||||
if len(inputs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO items (feed_id, guid, title, link, content, pub_date)
|
||||
VALUES (:feed_id, :guid, :title, :link, :content, :pub_date)
|
||||
ON CONFLICT(feed_id, guid) DO NOTHING
|
||||
`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
created := 0
|
||||
for _, input := range inputs {
|
||||
result, err := stmt.Exec(
|
||||
sql.Named("feed_id", feedID),
|
||||
sql.Named("guid", input.GUID),
|
||||
sql.Named("title", input.Title),
|
||||
sql.Named("link", input.Link),
|
||||
sql.Named("content", input.Content),
|
||||
sql.Named("pub_date", input.PubDate),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected > 0 {
|
||||
created++
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateItemUnread(id int64, unread bool) error {
|
||||
result, err := s.db.Exec(`UPDATE items SET unread = :unread WHERE id = :id`,
|
||||
sql.Named("unread", boolToInt(unread)), sql.Named("id", id))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -39,15 +40,17 @@ func TestMigrationVersionTracking(t *testing.T) {
|
||||
t.Error("schema_migrations table is empty, but migrations should have been applied")
|
||||
}
|
||||
|
||||
// Verify version 1 was applied
|
||||
var applied bool
|
||||
err = store.db.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = 1)").Scan(&applied)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to check version 1: %v", err)
|
||||
}
|
||||
|
||||
if !applied {
|
||||
t.Error("migration version 1 was not applied")
|
||||
// Verify all expected versions were applied
|
||||
versions := []int{1, 2, 3}
|
||||
for _, version := range versions {
|
||||
var applied bool
|
||||
err = store.db.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = :version)", sql.Named("version", version)).Scan(&applied)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to check version %d: %v", version, err)
|
||||
}
|
||||
if !applied {
|
||||
t.Errorf("migration version %d was not applied", version)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE feeds ADD COLUMN last_failure_at INTEGER NOT NULL DEFAULT 0;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX IF NOT EXISTS idx_items_feed_unread ON items(feed_id, unread);
|
||||
@@ -6,17 +6,37 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
"modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
var sqliteHookOnce sync.Once
|
||||
|
||||
func New(dbPath string) (*Store, error) {
|
||||
sqliteHookOnce.Do(func() {
|
||||
sqlite.RegisterConnectionHook(func(conn sqlite.ExecQuerierContext, _ string) error {
|
||||
ctx := context.Background()
|
||||
if _, err := conn.ExecContext(ctx, "PRAGMA foreign_keys = ON", nil); err != nil {
|
||||
return fmt.Errorf("enable foreign_keys: %w", err)
|
||||
}
|
||||
if _, err := conn.ExecContext(ctx, "PRAGMA busy_timeout = 5000", nil); err != nil {
|
||||
return fmt.Errorf("set busy_timeout: %w", err)
|
||||
}
|
||||
if _, err := conn.ExecContext(ctx, "PRAGMA journal_mode = WAL", nil); err != nil {
|
||||
return fmt.Errorf("set journal_mode: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open database: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user