Merge branch 'feature/split-tunneling' into release/v2.2.0

This commit is contained in:
Alexandr Stelnykovych
2026-05-18 17:52:10 +03:00
68 changed files with 4990 additions and 138 deletions
+8
View File
@@ -145,10 +145,18 @@ func GetLogLevel() Severity {
// SetLogLevel sets a new log level. Only effective after Start().
func SetLogLevel(level Severity) {
previous := GetLogLevel()
atomic.StoreUint32(logLevel, uint32(level))
// Setup slog here for the transition period.
setupSLog(level)
// Write directly to GlobalWriter (bypassing the level filter) so the
// message is always visible regardless of the current or new level.
if previous != level {
writeLogLevelChange(previous.Name(), level.Name())
}
}
// Name returns the name of the log level.
+22
View File
@@ -83,6 +83,28 @@ func writeVersion() {
}
}
func writeLogLevelChange(from, to string) {
if GlobalWriter == nil {
return
}
if GlobalWriter.isStdout {
fmt.Fprintf(GlobalWriter, "%s%s%s log level changed from %s%s%s to %s%s%s\n",
dimColor(),
time.Now().Format(timeFormat),
endDimColor(),
blueColor(),
from,
endColor(),
blueColor(),
to,
endColor())
} else {
fmt.Fprintf(GlobalWriter, "%s log level changed from %s to %s\n", time.Now().Format(timeFormat), from, to)
}
}
func writerManager() {
defer shutdownWaitGroup.Done()
+42 -31
View File
@@ -5,51 +5,62 @@ import (
"log/slog"
"os"
"runtime"
"sync"
"github.com/lmittmann/tint"
"github.com/mattn/go-colorable"
"github.com/mattn/go-isatty"
)
// slogLevel is the shared log level variable for the default slog logger.
// All loggers derived from slog.Default() (e.g. via .With()) share the same
// underlying handler and therefore respect changes to this variable immediately.
var (
slogLevel = new(slog.LevelVar)
slogSetupOnce sync.Once
)
func setupSLog(level Severity) {
// TODO: Changes in the log level are not yet reflected onto the slog handlers in the modules.
// Update the shared level variable. All existing handlers and derived
// loggers read from this pointer, so they pick up the change instantly.
slogLevel.Set(level.toSLogLevel())
// Set highest possible level, so it can be changed in runtime.
handlerLogLevel := level.toSLogLevel()
// Create the handler and set slog.Default() exactly once, so that
// managers created after startup always hold a logger whose underlying
// handler is controlled by slogLevel.
slogSetupOnce.Do(func() {
var logHandler slog.Handler
switch runtime.GOOS {
case "windows":
logHandler = tint.NewHandler(
windowsColoring(GlobalWriter), // Enable coloring on Windows.
&tint.Options{
AddSource: true,
Level: slogLevel,
TimeFormat: timeFormat,
NoColor: !( /* Color: */ GlobalWriter.IsStdout() && isatty.IsTerminal(GlobalWriter.file.Fd())),
},
)
// Create handler depending on OS.
var logHandler slog.Handler
switch runtime.GOOS {
case "windows":
logHandler = tint.NewHandler(
windowsColoring(GlobalWriter), // Enable coloring on Windows.
&tint.Options{
case "linux":
logHandler = tint.NewHandler(GlobalWriter, &tint.Options{
AddSource: true,
Level: handlerLogLevel,
Level: slogLevel,
TimeFormat: timeFormat,
NoColor: !( /* Color: */ GlobalWriter.IsStdout() && isatty.IsTerminal(GlobalWriter.file.Fd())),
},
)
})
case "linux":
logHandler = tint.NewHandler(GlobalWriter, &tint.Options{
AddSource: true,
Level: handlerLogLevel,
TimeFormat: timeFormat,
NoColor: !( /* Color: */ GlobalWriter.IsStdout() && isatty.IsTerminal(GlobalWriter.file.Fd())),
})
default:
logHandler = tint.NewHandler(os.Stdout, &tint.Options{
AddSource: true,
Level: slogLevel,
TimeFormat: timeFormat,
NoColor: true,
})
}
default:
logHandler = tint.NewHandler(os.Stdout, &tint.Options{
AddSource: true,
Level: handlerLogLevel,
TimeFormat: timeFormat,
NoColor: true,
})
}
// Set as default logger.
slog.SetDefault(slog.New(logHandler))
slog.SetDefault(slog.New(logHandler))
})
}
func windowsColoring(lw *LogWriter) io.Writer {
@@ -4,7 +4,7 @@ import { Observable, forkJoin, of } from "rxjs";
import { catchError, map, mergeMap } from "rxjs/operators";
import { AppProfileService } from "./app-profile.service";
import { AppProfile } from "./app-profile.types";
import { DNSContext, IPScope, Reason, TLSContext, TunnelContext, Verdict } from "./network.types";
import { DNSContext, IPScope, Reason, SplitTunContext, TLSContext, TunnelContext, Verdict } from "./network.types";
import { PORTMASTER_HTTP_API_ENDPOINT, PortapiService } from "./portapi.service";
import { Container } from "postcss";
@@ -162,6 +162,7 @@ export interface NetqueryConnection {
blockedEntities?: string[];
reason?: Reason;
tunnel?: TunnelContext;
split_tun?: SplitTunContext;
dns?: DNSContext;
tls?: TLSContext;
};
@@ -459,6 +460,7 @@ export class Netquery {
case Verdict.Accept:
case Verdict.RerouteToNs:
case Verdict.RerouteToTunnel:
case Verdict.RerouteToSplitTun:
case Verdict.Undeterminable:
stats.size += res.totalCount
stats.countAllowed += res.totalCount;
@@ -8,7 +8,8 @@ export enum Verdict {
Drop = 4,
RerouteToNs = 5,
RerouteToTunnel = 6,
Failed = 7
Failed = 7,
RerouteToSplitTun = 8
}
export enum IPProtocol {
@@ -209,6 +210,13 @@ export interface TunnelContext {
RoutingAlg: 'default';
}
export interface SplitTunContext {
// Interface is the name of the network interface the connection is bound to.
Interface: string;
// IP is the IP address used to bind the connection to the interface.
IP: string;
}
export interface GeoIPInfo {
IP: string;
Country: string;
+2
View File
@@ -29,6 +29,7 @@ import { AppOverviewComponent, AppViewComponent, QuickSettingInternetButtonCompo
import { QsHistoryComponent } from './pages/app-view/qs-history/qs-history.component';
import { QuickSettingSelectExitButtonComponent } from './pages/app-view/qs-select-exit/qs-select-exit';
import { QuickSettingUseSPNButtonComponent } from './pages/app-view/qs-use-spn/qs-use-spn';
import { QuickSettingUseSplitTunButtonComponent } from './pages/app-view/qs-use-splittun/qs-use-splittun';
import { DashboardPageComponent } from './pages/dashboard/dashboard.component';
import { FeatureCardComponent } from './pages/dashboard/feature-card/feature-card.component';
import { MonitorPageComponent } from './pages/monitor';
@@ -138,6 +139,7 @@ const localeConfig = {
QuickSettingInternetButtonComponent,
QuickSettingUseSPNButtonComponent,
QuickSettingSelectExitButtonComponent,
QuickSettingUseSplitTunButtonComponent,
AppOverviewComponent,
PlaceholderComponent,
LoadingComponent,
@@ -76,6 +76,9 @@
<app-qs-select-exit [canUse]="canUseSPN" [settings]="profileSettings" (save)="saveSetting($event)">
</app-qs-select-exit>
<app-qs-use-splittun [settings]="profileSettings" (save)="saveSetting($event)">
</app-qs-use-splittun>
<button class="flex flex-row gap-2 items-center px-4 bg-gray-300 btn" cdkOverlayOrigin #overlayOrigin="cdkOverlayOrigin" (click)="profileMenu.dropdown.toggle(overlayOrigin)">
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" class="w-4 h-4">
<path stroke-linecap="round" stroke-linejoin="round" d="M10.343 3.94c.09-.542.56-.94 1.11-.94h1.093c.55 0 1.02.398 1.11.94l.149.894c.07.424.384.764.78.93.398.164.855.142 1.205-.108l.737-.527a1.125 1.125 0 011.45.12l.773.774c.39.389.44 1.002.12 1.45l-.527.737c-.25.35-.272.806-.107 1.204.165.397.505.71.93.78l.893.15c.543.09.94.56.94 1.109v1.094c0 .55-.397 1.02-.94 1.11l-.893.149c-.425.07-.765.383-.93.78-.165.398-.143.854.107 1.204l.527.738c.32.447.269 1.06-.12 1.45l-.774.773a1.125 1.125 0 01-1.449.12l-.738-.527c-.35-.25-.806-.272-1.203-.107-.397.165-.71.505-.781.929l-.149.894c-.09.542-.56.94-1.11.94h-1.094c-.55 0-1.019-.398-1.11-.94l-.148-.894c-.071-.424-.384-.764-.781-.93-.398-.164-.854-.142-1.204.108l-.738.527c-.447.32-1.06.269-1.45-.12l-.773-.774a1.125 1.125 0 01-.12-1.45l.527-.737c.25-.35.273-.806.108-1.204-.165-.397-.505-.71-.93-.78l-.894-.15c-.542-.09-.94-.56-.94-1.109v-1.094c0-.55.398-1.02.94-1.11l.894-.149c.424-.07.765-.383.93-.78.165-.398.143-.854-.107-1.204l-.527-.738a1.125 1.125 0 01.12-1.45l.773-.773a1.125 1.125 0 011.45-.12l.737.527c.35.25.807.272 1.204.107.397-.165.71-.505.78-.929l.15-.894z" />
@@ -0,0 +1,40 @@
<div
class="relative flex flex-wrap items-center justify-center w-full h-full gap-2 px-3 py-2 bg-gray-300 border border-gray-300 rounded shadow"
snfgTooltipPosition="right"
[sfng-tooltip]="(interferingSettings.length > 0 || spnEnabled) ? tooltipTemplate : null">
<span class="text-primary" [class.cursor-pointer]="interferingSettings.length > 0 || spnEnabled">
Split Tunnel
</span>
<sfng-toggle *ngIf="splitTunModuleEnabled === true"
[ngModel]="currentValue" (ngModelChange)="updateUseSplitTun($event)">
</sfng-toggle>
<span *ngIf="splitTunModuleEnabled === false"
routerLink="/settings" [queryParams]="{setting: 'splittun/enable'}" class="cursor-pointer text-tertiary hover:underline"
sfng-tooltip="Enable Split Tunneling to use.">
Disabled
</span>
<ng-template *ngIf="splitTunModuleEnabled === null">
<fa-icon icon="spinner" [spin]="true"></fa-icon>
</ng-template>
<span class="absolute right-0 block w-2 h-2 bg-yellow-300 border border-gray-100 rounded opacity-75"
[style.background-color]="spnFullOverride ? '#ef4444' : null"
style="top: 2px; transform: translateX(-2px)" *ngIf="interferingSettings.length > 0 || spnEnabled"></span>
</div>
<ng-template #tooltipTemplate>
Settings that may interfere with Split Tunnel:
<ul class="pl-4 list-disc">
<li *ngIf="spnEnabled && spnFullOverride">SPN is routing all traffic, fully bypassing Split Tunnel</li>
<li *ngIf="spnEnabled && !spnFullOverride">SPN is routing non-excluded connections, partially bypassing Split Tunnel</li>
<ng-container *ngFor="let setting of interferingSettings">
<li class="cursor-pointer hover:underline" [routerLink]="[]"
[queryParams]="{setting: setting.Key, tab: 'settings'}">
{{ setting.Name }}
</li>
</ng-container>
</ul>
</ng-template>
@@ -0,0 +1,119 @@
import { ChangeDetectionStrategy, ChangeDetectorRef, Component, DestroyRef, EventEmitter, Input, OnChanges, OnInit, Output, SimpleChanges, inject } from "@angular/core";
import { takeUntilDestroyed } from "@angular/core/rxjs-interop";
import { BoolSetting, ConfigService, Setting, StringArraySetting, getActualValue } from "@safing/portmaster-api";
import { SaveSettingEvent } from "src/app/shared/config/generic-setting/generic-setting";
const configKeys = {
splitTunUse: 'splittun/use',
splitTunEnable: 'splittun/enable',
splitTunUsagePolicy: 'splittun/usagePolicy',
spnUse: 'spn/use',
spnEnable: 'spn/enable',
spnUsagePolicy: 'spn/usagePolicy',
} as const;
@Component({
selector: 'app-qs-use-splittun',
templateUrl: './qs-use-splittun.html',
changeDetection: ChangeDetectionStrategy.OnPush
})
export class QuickSettingUseSplitTunButtonComponent implements OnInit, OnChanges {
private destroyRef = inject(DestroyRef);
@Input()
settings: Setting[] = [];
@Output()
save = new EventEmitter<SaveSettingEvent>();
currentValue = false;
/** App-level settings (Exclude rules in usagePolicy) that may interfere. */
interferingSettings: Setting[] = [];
/** Whether the Split Tunneling module is globally enabled. null = not yet loaded. */
splitTunModuleEnabled: boolean | null = null;
/** Whether SPN is enabled for this app — overrides split tunnel for SPN-routed connections. */
spnEnabled = false;
/** Whether SPN fully overrides split tunnel: SPN in use with no Exclude rules in spn/usagePolicy. */
spnFullOverride = false;
/** Whether the SPN module is globally enabled. */
private spnModuleEnabled = false;
constructor(
private configService: ConfigService,
private cdr: ChangeDetectorRef
) { }
ngOnChanges(changes: SimpleChanges): void {
if ('settings' in changes) {
this.currentValue = false;
const useSetting = this.settings.find(s => s.Key === configKeys.splitTunUse) as BoolSetting | undefined;
if (useSetting) {
this.currentValue = getActualValue(useSetting);
}
this.updateInterfering();
}
}
ngOnInit(): void {
this.configService.watch<BoolSetting>(configKeys.splitTunEnable)
.pipe(takeUntilDestroyed(this.destroyRef))
.subscribe(value => {
this.splitTunModuleEnabled = !!value;
this.updateInterfering();
this.cdr.markForCheck();
});
this.configService.watch<BoolSetting>(configKeys.spnEnable)
.pipe(takeUntilDestroyed(this.destroyRef))
.subscribe(value => {
this.spnModuleEnabled = !!value;
this.updateInterfering();
this.cdr.markForCheck();
});
}
updateUseSplitTun(enabled: boolean): void {
this.save.next({
isDefault: false,
key: configKeys.splitTunUse,
value: enabled,
});
}
private updateInterfering(): void {
this.interferingSettings = [];
this.spnEnabled = false;
this.spnFullOverride = false;
if (!this.currentValue || !this.splitTunModuleEnabled) {
return;
}
const spnUseSetting = this.settings.find(s => s.Key === configKeys.spnUse) as BoolSetting | undefined;
this.spnEnabled = this.spnModuleEnabled && !!spnUseSetting && !!getActualValue(spnUseSetting);
// If SPN is enabled, check if it fully overrides Split Tunnel (no Exclude rules in SPN policy)
if (this.spnEnabled) {
const spnPolicy = this.settings.find(s => s.Key === configKeys.spnUsagePolicy) as StringArraySetting | undefined;
const spnPolicyValue = spnPolicy ? getActualValue(spnPolicy) : [];
const hasSpnExcludeRules = Array.isArray(spnPolicyValue) && spnPolicyValue.some(rule => rule.startsWith('- ') || rule === '-');
this.spnFullOverride = !hasSpnExcludeRules;
}
// Exclude rules in usagePolicy may prevent some connections from being tunneled
const usagePolicy = this.settings.find(s => s.Key === configKeys.splitTunUsagePolicy) as StringArraySetting | undefined;
if (usagePolicy) {
const value = getActualValue(usagePolicy);
if (Array.isArray(value) && value.some(rule => rule.startsWith('- ') || rule === '-')) {
this.interferingSettings.push(usagePolicy);
}
}
}
}
@@ -6,6 +6,9 @@
<sfng-tipup key="dashboardIntro" placement="left"></sfng-tipup>
</h1>
<!--
\\ DISABLED because: Displaying the SPN username in a prominent location may raise security and privacy concerns for some users.
<span class="text-sm font-normal text-secondary">
<ng-container *ngIf="!!profile; else: noUsername">
Welcome back, <span class="text-primary">{{ profile.username }}</span>!
@@ -16,6 +19,7 @@
Welcome back!
</ng-template>
</span>
-->
</div>
<div class="flex flex-row gap-8">
@@ -73,10 +73,10 @@ export class FeatureCardComponent implements OnChanges, OnDestroy {
}
let key: string | undefined;
if (this.feature?.ConfigScope) {
key = 'config:' + this.feature?.ConfigScope;
if (this.feature?.ConfigKey) {
key = this.feature?.ConfigKey;
} else {
key = this.feature?.ConfigKey;
key = 'config:' + this.feature?.ConfigScope;
}
if (!key) {
@@ -427,11 +427,16 @@ export class ConfigSettingsViewComponent
(s) => s.Key === subsys.ToggleOptionKey
);
if (!!toggleOption) {
if (
(toggleOption.Value !== undefined && !toggleOption.Value) ||
(toggleOption.Value === undefined &&
!toggleOption.DefaultValue)
) {
// Determine the effective enabled state: per-app value takes
// priority, then the globally-configured value (GlobalDefault),
// and finally the hardcoded DefaultValue.
const effectiveEnabled =
toggleOption.Value !== undefined
? !!toggleOption.Value
: toggleOption.GlobalDefault !== undefined
? !!toggleOption.GlobalDefault
: !!toggleOption.DefaultValue;
if (!effectiveEnabled) {
subsys.isDisabled = true;
// remove all settings for all subsystem categories
@@ -7,7 +7,7 @@ export interface SubsystemWithExpertise extends Subsystem {
hasUserDefinedValues: boolean;
}
export var subsystems : SubsystemWithExpertise[] = [
export const subsystems : SubsystemWithExpertise[] = [
{
minimumExpertise: ExpertiseLevelNumber.developer,
isDisabled: false,
@@ -268,5 +268,30 @@ export var subsystems : SubsystemWithExpertise[] = [
Deleted: 0,
Key: "runtime:subsystems/spn"
}
},
{
minimumExpertise: ExpertiseLevelNumber.user, // User level since UI is user-facing
isDisabled: false,
hasUserDefinedValues: false,
ID: "splittun",
Name: "Split Tunnel",
Description: "Route traffic through specified interface to bypass default routing",
Modules: [
{
Name: "splittun",
Enabled: true
}
],
ToggleOptionKey: "splittun/enable", // Links to the boolean enable/disable option
ExpertiseLevel: "user",
ReleaseLevel: 0,
ConfigKeySpace: "config:splittun/",
_meta: {
Created: 0,
Modified: 0,
Expires: 0,
Deleted: 0,
Key: "runtime:subsystems/splittun"
}
}
];
@@ -221,7 +221,7 @@
</span>
</div>
<div *ngIf="conn.scope === scopes.Global">
<div *ngIf="conn.scope === scopes.Global && !!conn.tunneled">
<h3 class="text-primary text-xxs">SPN Tunnel</h3>
<ng-container [ngSwitch]="true">
<span *ngSwitchCase="!conn.tunneled" class="inline-flex items-center gap-2 text-secondary">
@@ -255,6 +255,17 @@
</ng-container>
</div>
<div *ngIf="!!conn.extra_data?.split_tun">
<h3 class="text-primary text-xxs">Split Tunnel</h3>
<div *ngIf="conn.extra_data?.split_tun as splitTun" class="meta">
<span class="inline-flex items-center gap-1 flex-wrap">
<span class="text-secondary">This connection is routed through interface</span>
<span>{{ splitTun.Interface }}</span>
<span class="text-secondary">({{ splitTun.IP }})</span>
</span>
</div>
</div>
<div *ngIf="!!bwData.length" class="col-span-3 block border-t border-gray-400 py-2">
<h2 class="text-secondary uppercase w-full text-center text-xxs">Data Usage</h2>
<sfng-netquery-line-chart class="block w-full !h-36" [data]="bwData" [config]="{
@@ -1206,7 +1206,8 @@ export class SfngNetqueryViewer implements OnInit, OnDestroy, AfterViewInit {
$in: [
Verdict.Accept,
Verdict.RerouteToNs,
Verdict.RerouteToTunnel
Verdict.RerouteToTunnel,
Verdict.RerouteToSplitTun
],
}
},
+2 -1
View File
@@ -28,7 +28,8 @@ span.verdict {
&.accept,
&.reroutetons,
&.reroutetotunnel {
&.reroutetotunnel,
&.reroutetosplittun {
--bg-color: theme('colors.info.green');
}
+114
View File
@@ -0,0 +1,114 @@
#!/usr/bin/env bash
set -euo pipefail
# This script builds the Angular project for the Portmaster application and packages it into a zip file.
# The script assumes that all necessary dependencies are installed and available.
# Output file: dist/portmaster.zip
DEVELOPMENT=false
INTERACTIVE=false
usage() {
cat <<'EOF'
Usage: build_angular.sh [options]
Options:
-d, --development Build Angular and libs in development mode
-i, --interactive Ask before running install and libs build steps
-h, --help Show this help message
EOF
}
have() {
command -v "$1" >/dev/null 2>&1
}
ask_yes_no_default_yes() {
local prompt=$1
local reply
read -r -p "$prompt (Y/N, default: Y) " reply
reply=${reply:-Y}
[[ ! $reply =~ ^[Nn]$ ]]
}
while [[ $# -gt 0 ]]; do
case "$1" in
-d|--development)
DEVELOPMENT=true
shift
;;
-i|--interactive)
INTERACTIVE=true
shift
;;
-h|--help)
usage
exit 0
;;
*)
echo "Unknown argument: $1" >&2
usage
exit 2
;;
esac
done
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd -- "${SCRIPT_DIR}/../../../" && pwd)"
OUTPUT_DIR="${SCRIPT_DIR}/dist"
ORIGINAL_DIR="$(pwd)"
cleanup() {
cd "${ORIGINAL_DIR}" >/dev/null 2>&1 || true
}
trap cleanup EXIT
mkdir -p "${OUTPUT_DIR}"
if ! have npm; then
echo "Error: npm not found in PATH." >&2
exit 1
fi
if ! have zip; then
echo "Error: zip not found in PATH." >&2
echo "Install it using your package manager (for example: sudo apt install zip)." >&2
exit 1
fi
cd "${PROJECT_ROOT}/desktop/angular"
if ! $INTERACTIVE || ask_yes_no_default_yes "Run 'npm install'?"; then
npm install
fi
if ! $INTERACTIVE || ask_yes_no_default_yes "Build shared libraries?"; then
if $DEVELOPMENT; then
echo "Building shared libraries in development mode"
npm run build-libs:dev
else
echo "Building shared libraries in production mode"
npm run build-libs
fi
fi
if $DEVELOPMENT; then
echo "Building Angular project in development mode"
./node_modules/.bin/ng build --configuration development --base-href /ui/modules/portmaster/ portmaster
else
echo "Building Angular project in production mode"
NODE_ENV=production ./node_modules/.bin/ng build --configuration production --base-href /ui/modules/portmaster/ portmaster
fi
DESTINATION_ZIP="${OUTPUT_DIR}/portmaster.zip"
echo "Creating zip archive"
rm -f "${DESTINATION_ZIP}"
(
cd dist
zip -r "${DESTINATION_ZIP}" .
)
echo "Build completed successfully: ${DESTINATION_ZIP}"
echo
echo "To replace the currently installed UI bundle, use:"
echo " sudo cp -f /usr/lib/portmaster/portmaster.zip /usr/lib/portmaster/portmaster.zip.bak"
echo " sudo cp -f \"${DESTINATION_ZIP}\" /usr/lib/portmaster/portmaster.zip"
+97 -5
View File
@@ -12,6 +12,7 @@ import (
"github.com/safing/portmaster/base/log"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network"
pmpacket "github.com/safing/portmaster/service/network/packet"
)
var nfct *ct.Nfct // Conntrack handler. NFCT: Network Filter Connection Tracking.
@@ -33,6 +34,91 @@ func TeardownNFCT() {
}
}
// DeleteUnmarkedConnections deletes all conntrack entries with connmark=0,
// excluding loopback connections.
// These entries represent connections established while Portmaster was not
// running or was paused and therefore never received a verdict mark.
//
// The Linux netfilter nat table applies DNAT only to the first packet of a NEW
// connection. ESTABLISHED connections bypass the nat table entirely, so any
// routing decision (e.g. MarkRerouteSPN, MarkRerouteSplitTun) would never take
// effect for them. Removing their conntrack entries forces applications to
// reconnect; the resulting SYN is processed by NFQUEUE as a new connection and
// the correct DNAT rule fires.
//
// Loopback connections (source or destination is 127.x.x.x / ::1) are skipped.
// They always carry connmark=0 because Portmaster never saves a permanent mark
// for loopback-destined packets. Flushing them would needlessly disconnect apps
// talking to local services (databases, dev servers, local APIs, etc.).
//
// Connections already processed by Portmaster carry a non-zero connmark and
// are handled via CONNMARK --restore-mark; they are unaffected.
func DeleteUnmarkedConnections() error {
if nfct == nil {
return errors.New("nfq: nfct not initialized")
}
deleted := deleteUnmarkedConnections(nfct, ct.IPv4)
if netenv.IPv6Enabled() {
deleted += deleteUnmarkedConnections(nfct, ct.IPv6)
}
log.Infof("nfq: deleted %d unmarked conntrack entries to force re-evaluation on firewall activation", deleted)
return nil
}
func deleteUnmarkedConnections(nfct *ct.Nfct, f ct.Family) (deleted int) {
filter := ct.FilterAttr{
Mark: []byte{0x00, 0x00, 0x00, 0x00},
MarkMask: []byte{0xFF, 0xFF, 0xFF, 0xFF},
}
connections, err := nfct.Query(ct.Conntrack, f, filter)
if err != nil {
log.Warningf("nfq: error querying unmarked conntrack entries: %s", err)
return 0
}
var lastErr error
for _, connection := range connections {
if isLoopbackConnection(connection) {
continue
}
if err := nfct.Delete(ct.Conntrack, f, connection); err != nil {
lastErr = err
} else {
deleted++
}
}
if lastErr != nil {
log.Warningf("nfq: some unmarked conntrack entries could not be deleted, last error: %s", lastErr)
}
return deleted
}
// isLoopbackConnection reports whether a conntrack entry involves a loopback address.
func isLoopbackConnection(c ct.Con) bool {
if c.Origin != nil {
if c.Origin.Src != nil && c.Origin.Src.IsLoopback() {
return true
}
if c.Origin.Dst != nil && c.Origin.Dst.IsLoopback() {
return true
}
}
if c.Reply != nil {
if c.Reply.Src != nil && c.Reply.Src.IsLoopback() {
return true
}
if c.Reply.Dst != nil && c.Reply.Dst.IsLoopback() {
return true
}
}
return false
}
// DeleteAllMarkedConnection deletes all marked entries from the conntrack table.
func DeleteAllMarkedConnection() error {
if nfct == nil {
@@ -53,7 +139,7 @@ func DeleteAllMarkedConnection() error {
func deleteMarkedConnections(nfct *ct.Nfct, f ct.Family) (deleted int) {
// initialize variables
permanentFlags := []uint32{MarkAcceptAlways, MarkBlockAlways, MarkDropAlways, MarkRerouteNS, MarkRerouteSPN}
permanentFlags := []uint32{MarkAcceptAlways, MarkBlockAlways, MarkDropAlways, MarkRerouteNS, MarkRerouteSPN, MarkRerouteSplitTun}
filter := ct.FilterAttr{}
filter.MarkMask = []byte{0xFF, 0xFF, 0xFF, 0xFF}
filter.Mark = []byte{0x00, 0x00, 0x00, 0x00} // 4 zeros starting value
@@ -70,8 +156,8 @@ func deleteMarkedConnections(nfct *ct.Nfct, f ct.Family) (deleted int) {
}
for _, connection := range currentConnections {
deleteError = nfct.Delete(ct.Conntrack, ct.IPv4, connection)
if err != nil {
deleteError = nfct.Delete(ct.Conntrack, f, connection)
if deleteError != nil {
numberOfErrors++
} else {
deleted++
@@ -102,7 +188,13 @@ func DeleteMarkedConnection(conn *network.Connection) error {
},
},
}
connections, err := nfct.Get(ct.Conntrack, ct.IPv4, con)
family := ct.IPv4
if conn.IPVersion == pmpacket.IPv6 {
family = ct.IPv6
}
connections, err := nfct.Get(ct.Conntrack, family, con)
if err != nil {
return fmt.Errorf("nfq: failed to find entry for connection %s: %w", conn.String(), err)
}
@@ -112,7 +204,7 @@ func DeleteMarkedConnection(conn *network.Connection) error {
}
for _, connection := range connections {
deleteErr := nfct.Delete(ct.Conntrack, ct.IPv4, connection)
deleteErr := nfct.Delete(ct.Conntrack, family, connection)
if err == nil {
err = deleteErr
}
+16 -9
View File
@@ -18,15 +18,16 @@ import (
// See TODO on packet.mark() on their relevance
// and a possibility to remove most IPtables rules.
const (
MarkAccept = 1700 // 0x6a4
MarkBlock = 1701 // 0x6a5
MarkDrop = 1702 // 0x6a6
MarkAcceptFinal = 1709 // 0x6ad Accept and finalize the decision in iptables. This should only be used for Portmaster-owned outbound connections.
MarkAcceptAlways = 1710 // 0x6ae
MarkBlockAlways = 1711 // 0x6af
MarkDropAlways = 1712 // 0x6b0
MarkRerouteNS = 1799 // 0x707
MarkRerouteSPN = 1717 // 0x6b5
MarkAccept = 1700 // 0x6a4
MarkBlock = 1701 // 0x6a5
MarkDrop = 1702 // 0x6a6
MarkAcceptFinal = 1709 // 0x6ad Accept and finalize the decision in iptables. This should only be used for Portmaster-owned outbound connections.
MarkAcceptAlways = 1710 // 0x6ae
MarkBlockAlways = 1711 // 0x6af
MarkDropAlways = 1712 // 0x6b0
MarkRerouteNS = 1799 // 0x707
MarkRerouteSPN = 1717 // 0x6b5
MarkRerouteSplitTun = 1719 // 0x6b7
)
func markToString(mark int) string {
@@ -49,6 +50,8 @@ func markToString(mark int) string {
return "RerouteNS"
case MarkRerouteSPN:
return "RerouteSPN"
case MarkRerouteSplitTun:
return "RerouteSplitTun"
}
return "unknown"
}
@@ -192,3 +195,7 @@ func (pkt *packet) RerouteToNameserver() error {
func (pkt *packet) RerouteToTunnel() error {
return pkt.mark(MarkRerouteSPN)
}
func (pkt *packet) RerouteToSplitTun() error {
return pkt.mark(MarkRerouteSplitTun)
}
+21 -4
View File
@@ -85,12 +85,14 @@ func init() {
"filter PORTMASTER-FILTER -m mark --mark 1711 -p icmp -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1711 -j REJECT --reject-with icmp-admin-prohibited",
"filter PORTMASTER-FILTER -m mark --mark 1712 -j DROP",
"filter PORTMASTER-FILTER -m mark --mark 1717 -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1717 -j RETURN", // informational (non-functional) RETURN verdicts at the end of the chain
"filter PORTMASTER-FILTER -m mark --mark 1719 -j RETURN", // informational (non-functional) RETURN verdicts at the end of the chain
"nat PORTMASTER-REDIRECT -m mark --mark 1799 -p udp -j DNAT --to 127.0.0.17:53",
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p tcp -j DNAT --to 127.0.0.17:717",
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p udp -j DNAT --to 127.0.0.17:717",
// "nat PORTMASTER-REDIRECT -m mark --mark 1717 ! -p tcp ! -p udp -j DNAT --to 127.0.0.17",
"nat PORTMASTER-REDIRECT -m mark --mark 1719 -p tcp -j DNAT --to 127.0.0.17:719",
"nat PORTMASTER-REDIRECT -m mark --mark 1719 -p udp -j DNAT --to 127.0.0.17:719",
}
v4once = []string{
@@ -132,12 +134,14 @@ func init() {
"filter PORTMASTER-FILTER -m mark --mark 1711 -p icmpv6 -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1711 -j REJECT --reject-with icmp6-adm-prohibited",
"filter PORTMASTER-FILTER -m mark --mark 1712 -j DROP",
"filter PORTMASTER-FILTER -m mark --mark 1717 -j RETURN",
"filter PORTMASTER-FILTER -m mark --mark 1717 -j RETURN", // informational (non-functional) RETURN verdicts at the end of the chain
"filter PORTMASTER-FILTER -m mark --mark 1719 -j RETURN", // informational (non-functional) RETURN verdicts at the end of the chain
"nat PORTMASTER-REDIRECT -m mark --mark 1799 -p udp -j DNAT --to [::1]:53",
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p tcp -j DNAT --to [::1]:717",
"nat PORTMASTER-REDIRECT -m mark --mark 1717 -p udp -j DNAT --to [::1]:717",
// "nat PORTMASTER-REDIRECT -m mark --mark 1717 ! -p tcp ! -p udp -j DNAT --to [::1]",
"nat PORTMASTER-REDIRECT -m mark --mark 1719 -p tcp -j DNAT --to [::1]:719",
"nat PORTMASTER-REDIRECT -m mark --mark 1719 -p udp -j DNAT --to [::1]:719",
}
v6once = []string{
@@ -167,8 +171,21 @@ func activateNfqueueFirewall() error {
if err := nfq.InitNFCT(); err != nil {
return err
}
// Remove stale conntrack entries carrying Portmaster marks.
// This is required to prevent conflicts with existing entries if Portmaster was not cleanly stopped,
// and to ensure a clean state on firewall activation.
_ = nfq.DeleteAllMarkedConnection()
// Force re-evaluation of connections that bypassed Portmaster while it was
// stopped or paused. Without this, DNAT rules (SPN, Split-Tunneling) would
// never apply to already-established connections, as the nat table is only
// traversed for new connections.
//
// NOTE: This will disconnect all existing non-loopback connections with mark=0!
//
// TODO: normally, this is only necessary when DNAT-based routing features are active (e.g. SPN, Split-Tunneling)
_ = nfq.DeleteUnmarkedConnections()
return nil
}
@@ -65,3 +65,8 @@ func (p *tracedPacket) RerouteToTunnel() error {
defer p.markServed("reroute-tunnel")
return p.Packet.RerouteToTunnel()
}
func (p *tracedPacket) RerouteToSplitTun() error {
defer p.markServed("reroute-splittun")
return p.Packet.RerouteToSplitTun()
}
@@ -135,3 +135,11 @@ func (pkt *Packet) RerouteToTunnel() error {
}
return nil
}
// RerouteToSplitTun permanently reroutes the connection to the split tunnel (and the current packet).
func (pkt *Packet) RerouteToSplitTun() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, network.VerdictRerouteToSplitTun)
}
return nil
}
@@ -183,6 +183,8 @@ func getKextVerdictFromConnection(conn *network.Connection) kextinterface.KextVe
return kextinterface.VerdictRerouteToNameserver
case network.VerdictRerouteToTunnel:
return kextinterface.VerdictRerouteToTunnel
case network.VerdictRerouteToSplitTun:
return kextinterface.VerdictRerouteToSplitTun
case network.VerdictFailed:
return kextinterface.VerdictFailed
}
@@ -140,3 +140,11 @@ func (pkt *Packet) RerouteToTunnel() error {
}
return nil
}
// RerouteToSplitTun permanently reroutes the connection to the local split tunnel entrypoint (and the current packet).
func (pkt *Packet) RerouteToSplitTun() error {
if pkt.verdictSet.SetToIf(false, true) {
return SetVerdict(pkt, kextinterface.VerdictRerouteToSplitTun)
}
return nil
}
+28 -1
View File
@@ -25,6 +25,7 @@ import (
"github.com/safing/portmaster/service/network/netutils"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/service/resolver"
"github.com/safing/portmaster/spn/access"
)
@@ -464,6 +465,7 @@ func filterHandler(conn *network.Connection, pkt packet.Packet) {
}
filterConnection := true
checkTunnel := true
// Check for special (internal) connection cases.
switch {
@@ -474,6 +476,14 @@ func filterHandler(conn *network.Connection, pkt packet.Packet) {
filterConnection = false
log.Tracer(pkt.Ctx()).Infof("filter: granting own pre-authenticated connection %s", conn)
case !conn.Inbound && isOwnSplitTunnelProxyConnection(conn):
// Approve connection and skip tunnel check.
conn.Accept("split tunnel connection proxied by Portmaster", noReasonOptionKey)
conn.Internal = true
filterConnection = false
checkTunnel = false
log.Tracer(pkt.Ctx()).Infof("filter: granting own pre-authenticated proxied split tunnel connection %s", conn)
// Redirect outbound DNS packets if enabled,
case dnsQueryInterception() &&
!module.instance.Resolver().IsDisabled() &&
@@ -505,7 +515,7 @@ func filterHandler(conn *network.Connection, pkt packet.Packet) {
}
// Apply privacy filter and check tunneling.
FilterConnection(pkt.Ctx(), conn, pkt, filterConnection, true)
FilterConnection(pkt.Ctx(), conn, pkt, filterConnection, checkTunnel)
// Decide how to continue handling connection.
switch {
@@ -529,6 +539,10 @@ func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.
return
}
// Always fetch location data so Country/ASN/ASOrg are available in the UI
// regardless of whether filtering or tunneling is active.
conn.Entity.FetchLocation(ctx)
// Check if external verdict handler is set, and if so, run it.
// Note! This block can override the filter and tunnel check flags!
if extHandler := externalVerdictHandler.Load(); extHandler != nil {
@@ -571,6 +585,11 @@ func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.
// Check if connection should be tunneled.
if checkTunnel {
checkTunneling(ctx, conn)
if conn.Verdict != network.VerdictRerouteToTunnel {
// SPN takes precedence over Split Tunnel, so only check split tunneling if not already set to tunnel.
checkSplitTunneling(ctx, conn)
}
}
// Request tunneling if no tunnel is set and connection should be tunneled.
@@ -585,6 +604,12 @@ func FilterConnection(ctx context.Context, conn *network.Connection, pkt packet.
// connection and the data will help with debugging and displaying in the UI.
conn.Failed(fmt.Sprintf("failed to request tunneling: %s", err), "")
}
} else if conn.Verdict == network.VerdictRerouteToSplitTun {
// Request split tunneling
err := requestSplitTunneling(ctx, conn)
if err != nil {
conn.Failed(fmt.Sprintf("failed to request split-tunneling: %s", err), profile.CfgOptionSplitTunUseKey)
}
}
}
@@ -844,6 +869,8 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V
err = pkt.RerouteToNameserver()
case network.VerdictRerouteToTunnel:
err = pkt.RerouteToTunnel()
case network.VerdictRerouteToSplitTun:
err = pkt.RerouteToSplitTun()
case network.VerdictFailed:
atomic.AddUint64(packetsFailed, 1)
err = pkt.Drop()
+121
View File
@@ -0,0 +1,121 @@
package firewall
import (
"context"
"errors"
"github.com/safing/portmaster/base/log"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/service/profile/endpoints"
"github.com/safing/portmaster/service/splittun"
)
func checkSplitTunneling(ctx context.Context, conn *network.Connection) {
// Check if the connection should be tunneled at all.
switch {
case conn.Entity.IPScope.IsLocalhost():
// Can't tunnel Local connections.
return
case conn.Inbound:
// Can't tunnel incoming connections.
return
case conn.Verdict != network.VerdictAccept:
// Connection will be blocked.
return
case conn.IPProtocol != packet.TCP && conn.IPProtocol != packet.UDP:
// Unsupported protocol.
return
case conn.Process().Pid == ownPID:
// Bypass tunneling for own connections.
return
case !splittun.IsReady():
return
}
// Get profile.
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
conn.Failed("no profile set", "")
return
}
// Update profile.
if layeredProfile.NeedsUpdate() {
// Update revision counter in connection.
conn.ProfileRevisionCounter = layeredProfile.Update(
conn.Process().MatchingData(),
conn.Process().CreateProfileCallback,
)
conn.SaveWhenFinished()
} else {
// Check if the revision counter of the connection needs updating.
revCnt := layeredProfile.RevisionCnt()
if conn.ProfileRevisionCounter != revCnt {
conn.ProfileRevisionCounter = revCnt
conn.SaveWhenFinished()
}
}
// Check if split-tunneling is enabled for this app at all.
if !layeredProfile.UseSplitTun() {
return
}
// Check if tunneling is enabled for entity.
conn.Entity.FetchData(ctx)
result, _ := layeredProfile.MatchSplitTunUsagePolicy(ctx, conn.Entity)
switch result {
case endpoints.MatchError:
conn.Failed("failed to check Split Tunnel rules", profile.CfgOptionSplitTunUsagePolicyKey)
return
case endpoints.Denied:
return
case endpoints.Permitted, endpoints.NoMatch:
}
conn.SaveWhenFinished()
conn.SetVerdictDirectly(network.VerdictRerouteToSplitTun)
}
func requestSplitTunneling(ctx context.Context, conn *network.Connection) error {
// Get profile.
layeredProfile := conn.Process().Profile()
if layeredProfile == nil {
return errors.New("no profile set")
}
interfaceToBind := layeredProfile.SplitTunInterface()
// Queue request in splittun module.
splitTunCtx, err := splittun.AwaitRequest(conn, interfaceToBind)
if err != nil {
return err
}
// Store context on the connection so the UI can display interface information.
conn.SplitTunContext = splitTunCtx
log.Tracer(ctx).Trace("filter: split tunneling requested")
return nil
}
func isOwnSplitTunnelProxyConnection(conn *network.Connection) bool {
switch {
case conn.Process().Pid != ownPID:
// Proxies are running only in our own process.
return false
case conn.Entity.IPScope.IsLocalhost():
// Local connections are not proxied.
return false
case conn.IPProtocol != packet.TCP && conn.IPProtocol != packet.UDP:
// Unsupported protocol.
return false
case !splittun.IsReady():
return false
}
return splittun.IsProxiedConnectionInfo(conn)
}
+10
View File
@@ -35,6 +35,7 @@ import (
"github.com/safing/portmaster/service/process"
"github.com/safing/portmaster/service/profile"
"github.com/safing/portmaster/service/resolver"
"github.com/safing/portmaster/service/splittun"
"github.com/safing/portmaster/service/status"
"github.com/safing/portmaster/service/sync"
"github.com/safing/portmaster/service/ui"
@@ -102,6 +103,8 @@ type Instance struct {
control *control.Control
interop *interop.Interoperability
splittun *splittun.SplitTunModule
access *access.Access
// SPN modules
@@ -283,6 +286,11 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { //nolint:maintidx
return instance, fmt.Errorf("create access module: %w", err)
}
instance.splittun, err = splittun.New(instance)
if err != nil {
return instance, fmt.Errorf("create splittun module: %w", err)
}
// SPN modules
instance.cabin, err = cabin.New(instance)
if err != nil {
@@ -355,6 +363,8 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { //nolint:maintidx
instance.filterLists,
instance.customlist,
instance.splittun,
instance.interop, // required to start before interception
// Grouped pausable interception modules:
+6
View File
@@ -132,6 +132,12 @@ func (e *Entity) DstPort() uint16 {
return e.dstPort
}
// FetchLocation fetches GeoIP location data for the entity.
// It is safe to call multiple times; the lookup runs only once.
func (e *Entity) FetchLocation(ctx context.Context) {
e.getLocation(ctx)
}
// FetchData fetches additional information, meant to be called before persisting an entity record.
func (e *Entity) FetchData(ctx context.Context) {
e.getLocation(ctx)
+2 -2
View File
@@ -54,7 +54,7 @@ func (i *InteropIvpn) onConnectionStopped(wc *mgr.WorkerCtx, _ string, _ string)
wc.Debug("IVPN: VPN connection stopped")
_ = i.ensureSPNCompatibility(wc)
i.reconcileCompatibilityState(wc)
}
// notification handler: VPN connection established successfully
@@ -73,5 +73,5 @@ func (i *InteropIvpn) onConnectedResp(wc *mgr.WorkerCtx, _ string, messageData s
wc.Debug(fmt.Sprintf("IVPN: VPN connection established (vpnType:%v; localIPv4:%v; localIPv6:%v)",
connectedResp.VpnType, connectedResp.ClientIP, connectedResp.ClientIPv6))
_ = i.ensureSPNCompatibility(wc)
i.reconcileCompatibilityState(wc)
}
+1 -2
View File
@@ -12,6 +12,5 @@ type platformSpecific struct{}
func (i *InteropIvpn) spnConnectingHook(w *mgr.WorkerCtx, hookCtx hub.Announcement) (cancel bool, err error) {
return true, nil
}
func (i *InteropIvpn) ensureSPNCompatibility(wc *mgr.WorkerCtx) error {
return nil
func (i *InteropIvpn) reconcileCompatibilityState(wc *mgr.WorkerCtx) {
}
+31 -29
View File
@@ -26,48 +26,49 @@ type platformSpecific struct {
const (
// NOTE: The nft table name is currently tied to IVPN's wg-quick setup.
// If IVPN changes the WG interface naming, this constant may need adjustment.
nftTableWgQuickIvpn = "wg-quick-wgivpn"
nftRuleCommentSPNCompat = "portmaster-spn-lo-rnat"
spnSlitTunRouteTableID = "717"
spnSlitTunRulePriority = "717"
nftTableWgQuickIvpn = "wg-quick-wgivpn"
wgKillswitchBypassComment = "portmaster-wg-ks-bypass"
spnSlitTunRouteTableID = "717"
spnSlitTunRulePriority = "717"
)
// spnConnectingHook is called when SPN is connecting to a hub.
func (i *InteropIvpn) spnConnectingHook(wc *mgr.WorkerCtx, homeHub hub.Announcement) (cancel bool, retErr error) {
err := i.ensureWgSpnCompatRule(wc)
i.doReconcileCompatibilityState(wc, &homeHub)
return false, nil
}
func (i *InteropIvpn) reconcileCompatibilityState(wc *mgr.WorkerCtx) {
i.doReconcileCompatibilityState(wc, i.extra.spnHubInfo.Load())
}
// doReconcileCompatibilityState reconciles WireGuard firewall and routing rules
// to maintain SPN and SplitTunnel compatibility with the current IVPN connection state.
func (i *InteropIvpn) doReconcileCompatibilityState(wc *mgr.WorkerCtx, hubInfo *hub.Announcement) {
// Ensure WireGuard-specific firewall rule is in place or removed as needed based on current VPN and SPN/ST state.
err := i.ensureWgCompatRule(wc)
if err != nil {
// Could happen, for example, if IVPN Client is paused
wc.Warn(fmt.Sprintf("IVPN: failed to ensure WireGuard compatibility rule: %v", err))
}
err = i.ensureSpnHubBypassVpnRoutes(wc, &homeHub)
// Ensure routing rules are in place to keep SPN hub traffic outside the VPN tunnel when connected,
// or clean up stale rules when disconnected.
err = i.ensureSpnHubBypassVpnRoutes(wc, hubInfo)
if err != nil {
wc.Warn(fmt.Sprintf("IVPN: failed to ensure VPN and SPN tunnel routes: %v", err))
}
return false, nil
}
func (i *InteropIvpn) ensureSPNCompatibility(wc *mgr.WorkerCtx) error {
err := i.ensureWgSpnCompatRule(wc)
if err != nil {
wc.Warn(fmt.Sprintf("IVPN: failed to ensure WireGuard compatibility rule: %v", err))
}
err = i.ensureSpnHubBypassVpnRoutes(wc, i.extra.spnHubInfo.Load())
if err != nil {
wc.Warn(fmt.Sprintf("IVPN: failed to ensure VPN and SPN tunnel routes: %v", err))
}
return nil
}
// SPN compatibility workaround for WireGuard kill-switch rules.
// SPN and SplitTunnel (ST) compatibility workaround for WireGuard kill-switch rules.
//
// WireGuard (wg-quick) installs a prerouting/raw kill-switch rule that drops
// packets destined to the WG local address when they arrive from non-WG interfaces.
// Portmaster SPN reverse-NAT replies are delivered via loopback (iif lo) with a
// Portmaster SPN/ST reverse-NAT replies are delivered via loopback (iif lo) with a
// non-local source, which matches that drop pattern and breaks the TCP handshake
// (SYN-SENT/SYN-RECV).
//
// To preserve the kill-switch behavior while allowing SPN reverse-NAT, Portmaster
// To preserve the kill-switch behavior while allowing SPN/ST reverse-NAT, Portmaster
// inserts a narrow exception rule before the wg-quick drop:
// - nft path (preferred):
// `iifname "lo" ip daddr <WG_LOCAL_IP> fib saddr type != local accept`
@@ -76,8 +77,8 @@ func (i *InteropIvpn) ensureSPNCompatibility(wc *mgr.WorkerCtx) error {
//
// Rule lifecycle is managed here:
// - Remove previously managed rule (nft/iptables) first.
// - Recreate only when WireGuard is connected and SPN is enabled.
func (i *InteropIvpn) ensureWgSpnCompatRule(wc *mgr.WorkerCtx) error {
// - Recreate only when WireGuard is connected and SPN/ST is enabled.
func (i *InteropIvpn) ensureWgCompatRule(wc *mgr.WorkerCtx) error {
status := i.getStatus()
connectedInfo := status.connectedInfo
@@ -103,7 +104,7 @@ func (i *InteropIvpn) ensureWgSpnCompatRule(wc *mgr.WorkerCtx) error {
"-d", oldRuleIP+"/32",
"-i", "lo",
"-m", "addrtype", "!", "--src-type", "LOCAL",
"-m", "comment", "--comment", nftRuleCommentSPNCompat,
"-m", "comment", "--comment", wgKillswitchBypassComment,
"-j", "ACCEPT",
).Run()
i.extra.spnWgIptRuleIP.Store("")
@@ -123,7 +124,8 @@ func (i *InteropIvpn) ensureWgSpnCompatRule(wc *mgr.WorkerCtx) error {
// If SPN not enabled -we do not need the rule
cfgSpnEnabled := config.GetAsBool("spn/enable", false)
if !cfgSpnEnabled() {
cfgSplittunEnabled := config.GetAsBool("splittun/enable", false)
if !cfgSpnEnabled() && !cfgSplittunEnabled() {
return nil
}
@@ -132,7 +134,7 @@ func (i *InteropIvpn) ensureWgSpnCompatRule(wc *mgr.WorkerCtx) error {
// sudo nft --echo --json insert rule ip wg-quick-wgivpn preraw iifname "lo" ip daddr 1.2.3.4 fib saddr type != local accept comment "portmaster-spn-lo-rnat"
out, err := exec.Command(nftPath, "--echo", "--json", "insert", "rule", "ip", nftTableWgQuickIvpn, "preraw",
"iifname", "lo", "ip", "daddr", wgLocalIP, "fib", "saddr", "type", "!=", "local", "accept",
"comment", nftRuleCommentSPNCompat).Output()
"comment", wgKillswitchBypassComment).Output()
if err != nil {
return fmt.Errorf("failed to insert nft rule: %w", err)
}
@@ -157,7 +159,7 @@ func (i *InteropIvpn) ensureWgSpnCompatRule(wc *mgr.WorkerCtx) error {
"-d", wgLocalIP+"/32",
"-i", "lo",
"-m", "addrtype", "!", "--src-type", "LOCAL",
"-m", "comment", "--comment", nftRuleCommentSPNCompat,
"-m", "comment", "--comment", wgKillswitchBypassComment,
"-j", "ACCEPT",
).Run()
if err != nil {
+1 -2
View File
@@ -28,6 +28,5 @@ func (i *InteropIvpn) spnConnectingHook(w *mgr.WorkerCtx, hookCtx hub.Announceme
return false, nil
}
func (i *InteropIvpn) ensureSPNCompatibility(wc *mgr.WorkerCtx) error {
return nil
func (i *InteropIvpn) reconcileCompatibilityState(wc *mgr.WorkerCtx) {
}
+6 -6
View File
@@ -125,8 +125,8 @@ func (i *InteropIvpn) connectIvpnClient(wc *mgr.WorkerCtx) error {
// Mark that the first connection attempt is done, even if it failed
i.setFirstTryDone()
// Ensure SPN compatibility rules are removed when Portmaster disconnects from IVPN client, either due to shutdown or connection failure.
i.ensureSPNCompatibility(wc)
// Ensure compatibility rules are removed when Portmaster disconnects from IVPN client, either due to shutdown or connection failure.
i.reconcileCompatibilityState(wc)
}()
notifWarn := notifWarnOldVersion.Load()
@@ -217,17 +217,17 @@ func (i *InteropIvpn) connectIvpnClient(wc *mgr.WorkerCtx) error {
i.updateIvpnClientDnsSettings(wc, client)
// Subscribe to interception start/stop events to update IVPN client DNS settings
interceptionStatus := i.owner.Interception().EventStartStopState.Subscribe("ivpn", 10)
defer interceptionStatus.Cancel()
interceptionStartStopStatus := i.owner.Interception().EventStartStopState.Subscribe("ivpn", 10)
defer interceptionStartStopStatus.Cancel()
done := false
for !done {
select {
case <-interceptionStatus.Events():
case <-interceptionStartStopStatus.Events():
i.updateIvpnClientDnsSettings(wc, client)
case <-i.owner.EvtConfigChange():
i.updateIvpnClientDnsSettings(wc, client)
i.ensureSPNCompatibility(wc)
i.reconcileCompatibilityState(wc)
case <-wc.Done():
client.Disconnect()
done = true
+2 -1
View File
@@ -205,7 +205,8 @@ func handleRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg)
switch conn.Verdict {
// We immediately save blocked, dropped or failed verdicts so
// they pop up in the UI.
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed, network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel:
case network.VerdictBlock, network.VerdictDrop, network.VerdictFailed,
network.VerdictRerouteToNameserver, network.VerdictRerouteToTunnel, network.VerdictRerouteToSplitTun:
conn.Save()
// For undecided or accepted connections we don't save them yet, because
+586
View File
@@ -0,0 +1,586 @@
package netenv
import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/safing/portmaster/base/log"
"github.com/safing/portmaster/service/network/netutils"
)
// cachedNetInterface holds a network interface with its pre-parsed IP addresses.
type cachedNetInterface struct {
iface net.Interface
addrs []net.IP // all routable addresses; used for IP-based lookup
macStr string // HardwareAddr.String() result, cached to avoid per-search allocations
firstIPv4 net.IP // first routable IPv4, or nil; cached to avoid scanning addrs on every call
firstIPv6 net.IP // first routable IPv6, or nil; cached to avoid scanning addrs on every call
}
// InterfaceInfo holds a matched network interface with the preferred
// IP addresses for each address family.
//
// When the interface was found by a specific IP, that IP is used as the
// preferred address for its address family; the first address of the
// other family (if any) is populated from the interface's address list.
// When the interface was found by name or MAC, the first address of each
// family from the address list is used.
type InterfaceInfo struct {
Interface *net.Interface
IPv4 net.IP // first routable IPv4 address for this interface; nil if none
IPv6 net.IP // first routable IPv6 address for this interface; nil if none
}
var (
// ifaceCache stores the latest enumerated network interfaces as a slice.
// A slice is used instead of maps because a typical host has only a handful
// of interfaces (210). Linear scans over such small slices are faster than
// map lookups: no hashing, no bucket pointer chasing, and the data fits
// entirely in a few cache lines. Maps would also require three separate
// structures (by name, IP, MAC), adding allocation and maintenance cost with
// no measurable benefit at real-world sizes.
// It is nil until the first call to any GetInterface* function (lazy init).
ifaceCache []cachedNetInterface
ifaceCacheLock sync.RWMutex
ifaceCacheChangedFlag = GetNetworkChangedFlag()
ifaceCacheRefreshError error //nolint:errname // Not what the linter thinks this is for.
ifaceCacheDontRefreshUntil time.Time
)
// refreshIfaceCache re-enumerates all network interfaces and stores them in ifaceCache.
// It also resets the network-changed flag.
// Refreshes are throttled to at most once per second to avoid redundant
// re-enumerations during rapid interface churn (e.g. network reconnects).
// The caller must hold ifaceCacheLock for writing.
func refreshIfaceCache() error {
// Throttle: return early if we refreshed very recently; the existing cache remains valid.
if time.Now().Before(ifaceCacheDontRefreshUntil) {
if ifaceCacheRefreshError != nil {
return fmt.Errorf("failed to previously refresh interface cache: %w", ifaceCacheRefreshError)
}
return nil
}
ifaceCacheRefreshError = nil
ifaceCacheDontRefreshUntil = time.Now().Add(1 * time.Second)
ifaces, err := net.Interfaces()
if err != nil {
ifaceCacheRefreshError = err
return fmt.Errorf("failed to enumerate network interfaces: %w", err)
}
newCache := make([]cachedNetInterface, 0, len(ifaces))
for i := range ifaces {
// Skip interfaces that are down — they have no usable IP connectivity.
if ifaces[i].Flags&net.FlagUp == 0 {
continue
}
// Skip loopback — it is not useful for cross-host communication.
if ifaces[i].Flags&net.FlagLoopback != 0 {
continue
}
entry := cachedNetInterface{
iface: ifaces[i],
macStr: ifaces[i].HardwareAddr.String(),
}
addrs, addrErr := ifaces[i].Addrs()
if addrErr != nil {
log.Warningf("netenv: failed to get addresses for interface %s: %v", ifaces[i].Name, addrErr)
} else {
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
// Skip addresses of unexpected types (switch default left ip nil).
if ip == nil {
continue
}
// Use the 4-byte form for IPv4 so it matches what was stored during cache build.
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
// Keep only routable unicast addresses (site-local or global).
if !isRoutableUnicastIP(ip) {
continue
}
entry.addrs = append(entry.addrs, ip)
}
}
// Skip interfaces with no usable unicast addresses — they cannot
// participate in normal IP connectivity and are not searchable by IP.
if len(entry.addrs) == 0 {
continue
}
// Pre-cache the first address of each family so buildInterfaceInfo
// can return them with two field reads instead of scanning addrs.
for _, ip := range entry.addrs {
if ip.To4() != nil {
if entry.firstIPv4 == nil {
entry.firstIPv4 = ip
}
} else {
if entry.firstIPv6 == nil {
entry.firstIPv6 = ip
}
}
if entry.firstIPv4 != nil && entry.firstIPv6 != nil {
break
}
}
newCache = append(newCache, entry)
}
ifaceCache = newCache
ifaceCacheChangedFlag.Refresh()
return nil
}
// ensureIfaceCache guarantees the cache is populated and up to date.
// The caller must hold ifaceCacheLock for writing.
func ensureIfaceCache() error {
if ifaceCache == nil || ifaceCacheChangedFlag.IsSet() {
return refreshIfaceCache()
}
return nil
}
// cacheReady reports whether the cache is populated and current.
// The caller must hold at least ifaceCacheLock for reading.
func cacheReady() bool {
return ifaceCache != nil && !ifaceCacheChangedFlag.IsSet()
}
// buildInterfaceInfo constructs an InterfaceInfo from a cache entry.
// If knownIP is non-nil it is used as the preferred address for its
// address family; the pre-cached first address of the other family fills
// the remaining field. Both fields are read directly from the entry — no
// scan of addrs is needed.
func buildInterfaceInfo(entry *cachedNetInterface, knownIP net.IP) *InterfaceInfo {
info := &InterfaceInfo{
Interface: &entry.iface,
IPv4: entry.firstIPv4,
IPv6: entry.firstIPv6,
}
// Override the matched family with the exact IP used to find this entry.
if knownIP != nil {
if knownIP.To4() != nil {
info.IPv4 = knownIP
} else {
info.IPv6 = knownIP
}
}
return info
}
// GetInterface returns the local network interface identified by ifinfo.
// ifinfo may be an IP address, a MAC address, or an interface name; they are
// tried in that order. An error is returned when no interface matches.
func GetInterface(ifinfo string) (*InterfaceInfo, error) {
// Fast path: concurrent reads when the cache is already valid.
ifaceCacheLock.RLock()
if cacheReady() {
entry, matchedIP := searchByIfinfo(ifinfo)
ifaceCacheLock.RUnlock()
if entry == nil {
return nil, fmt.Errorf("no interface found %q", ifinfo)
}
return buildInterfaceInfo(entry, matchedIP), nil
}
ifaceCacheLock.RUnlock()
// Slow path: refresh the cache, then search.
ifaceCacheLock.Lock()
defer ifaceCacheLock.Unlock()
if err := ensureIfaceCache(); err != nil {
return nil, err
}
entry, matchedIP := searchByIfinfo(ifinfo)
if entry == nil {
return nil, fmt.Errorf("no interface found %q", ifinfo)
}
return buildInterfaceInfo(entry, matchedIP), nil
}
// searchByIfinfo searches ifaceCache in priority order: IP → MAC → name.
// It returns the matched cache entry and, when the match was by IP, the
// normalised IP that was used (so the caller can pin it as the preferred
// address for that family). The IP return value is nil for name/MAC matches.
// The caller must hold ifaceCacheLock (for reading or writing).
func searchByIfinfo(ifinfo string) (*cachedNetInterface, net.IP) {
if ip := net.ParseIP(ifinfo); ip != nil {
normalized := normalizeIP(ip)
return searchIfaceByIP(normalized), normalized
}
if mac, err := net.ParseMAC(ifinfo); err == nil {
return searchIfaceByMAC(mac.String()), nil
}
return searchIfaceByName(ifinfo), nil
}
// GetInterfaceByIP returns the local network interface that has ip assigned.
func GetInterfaceByIP(ip net.IP) (*InterfaceInfo, error) {
if ip == nil {
return nil, fmt.Errorf("GetInterfaceByIP called with nil IP")
}
normalized := normalizeIP(ip)
ifaceCacheLock.RLock()
if cacheReady() {
entry := searchIfaceByIP(normalized)
ifaceCacheLock.RUnlock()
if entry == nil {
return nil, fmt.Errorf("no interface found with IP %s", ip)
}
return buildInterfaceInfo(entry, normalized), nil
}
ifaceCacheLock.RUnlock()
ifaceCacheLock.Lock()
defer ifaceCacheLock.Unlock()
if err := ensureIfaceCache(); err != nil {
return nil, err
}
if entry := searchIfaceByIP(normalized); entry != nil {
return buildInterfaceInfo(entry, normalized), nil
}
return nil, fmt.Errorf("no interface found with IP %s", ip)
}
// GetInterfaceByMAC returns the local network interface with the given hardware address.
func GetInterfaceByMAC(mac net.HardwareAddr) (*InterfaceInfo, error) {
macStr := mac.String()
ifaceCacheLock.RLock()
if cacheReady() {
entry := searchIfaceByMAC(macStr)
ifaceCacheLock.RUnlock()
if entry == nil {
return nil, fmt.Errorf("no interface found with MAC %s", mac)
}
return buildInterfaceInfo(entry, nil), nil
}
ifaceCacheLock.RUnlock()
ifaceCacheLock.Lock()
defer ifaceCacheLock.Unlock()
if err := ensureIfaceCache(); err != nil {
return nil, err
}
if entry := searchIfaceByMAC(macStr); entry != nil {
return buildInterfaceInfo(entry, nil), nil
}
return nil, fmt.Errorf("no interface found with MAC %s", mac)
}
// GetInterfaceByName returns the local network interface with the given name.
func GetInterfaceByName(name string) (*InterfaceInfo, error) {
ifaceCacheLock.RLock()
if cacheReady() {
entry := searchIfaceByName(name)
ifaceCacheLock.RUnlock()
if entry == nil {
return nil, fmt.Errorf("no interface found with name %q", name)
}
return buildInterfaceInfo(entry, nil), nil
}
ifaceCacheLock.RUnlock()
ifaceCacheLock.Lock()
defer ifaceCacheLock.Unlock()
if err := ensureIfaceCache(); err != nil {
return nil, err
}
if entry := searchIfaceByName(name); entry != nil {
return buildInterfaceInfo(entry, nil), nil
}
return nil, fmt.Errorf("no interface found with name %q", name)
}
// normalizeIP returns the 4-byte form of an IPv4 address, or the IP unchanged
// for IPv6. This matches the form stored in cachedNetInterface.addrs.
func normalizeIP(ip net.IP) net.IP {
if ip4 := ip.To4(); ip4 != nil {
return ip4
}
return ip
}
// isRoutableUnicastIP reports whether ip is a routable unicast address useful
// for real network communication: site-local (RFC 1918 / ULA) or globally
// routable. All other scopes — link-local, loopback, unspecified, multicast,
// and documentation/test ranges — are excluded.
// ip must already be in its canonical form (4-byte for IPv4, see normalizeIP).
func isRoutableUnicastIP(ip net.IP) bool {
scope := netutils.GetIPScope(ip)
return scope == netutils.SiteLocal || scope == netutils.Global
}
// searchIfaceByIP returns the cache entry whose address list contains ip, or nil.
// The caller must hold ifaceCacheLock for reading or writing.
func searchIfaceByIP(ip net.IP) *cachedNetInterface {
for i := range ifaceCache {
for _, addr := range ifaceCache[i].addrs {
if ip.Equal(addr) {
return &ifaceCache[i]
}
}
}
return nil
}
// searchIfaceByMAC returns the cache entry whose hardware address matches
// macStr (in canonical net.HardwareAddr.String() form), or nil.
// The caller must hold ifaceCacheLock for reading or writing.
func searchIfaceByMAC(macStr string) *cachedNetInterface {
for i := range ifaceCache {
if ifaceCache[i].macStr == macStr {
return &ifaceCache[i]
}
}
return nil
}
// searchIfaceByName returns the cache entry with the given name, or nil.
// The caller must hold ifaceCacheLock for reading or writing.
func searchIfaceByName(name string) *cachedNetInterface {
for i := range ifaceCache {
if ifaceCache[i].iface.Name == name {
return &ifaceCache[i]
}
}
return nil
}
// ---- Physical default interface ----
// PhysicalDefaultInterfaces holds the best physical network adapter per IP
// family, together with its preferred bind addresses. IPv4 and IPv6 may
// resolve to different interfaces — for example when a VPN tunnels only IPv4
// and IPv6 traffic exits directly on Ethernet, or when Ethernet serves IPv4
// and a mobile hotspot provides IPv6.
// A nil field means no physical interface with a default route for that family
// was found (e.g. IPv6 is simply not configured on this host).
type PhysicalDefaultInterfaces struct {
ForIPv4 *InterfaceInfo // best physical adapter handling the default IPv4 route; nil if none
ForIPv6 *InterfaceInfo // best physical adapter handling the default IPv6 route; nil if none
}
var (
physicalDefaultIfacesCache PhysicalDefaultInterfaces
physicalDefaultIfacesCacheValid bool
physicalDefaultIfacesLock sync.RWMutex
physicalDefaultIfacesChangedFlag = GetNetworkChangedFlag()
physicalDefaultIfacesDontRefreshUntil time.Time
)
// GetBestPhysicalDefaultInterfaces returns the physical network adapters
// (Ethernet, WiFi, mobile broadband) currently used for internet traffic,
// one per IP family. VPN, tunnel, and other virtual interfaces are explicitly
// excluded, making the result safe to use as split-tunnel bypass targets.
//
// Selection criteria per family (all must be satisfied):
// - Adapter type is physical hardware (Ethernet, WiFi, mobile broadband).
// - Has at least one routable unicast address for that family (not link-local).
// - Has a default route (0.0.0.0/0 or ::/0) in the routing table.
// - When multiple candidates qualify, the one with the lowest route metric wins.
//
// The result is cached and refreshed only when a network change is detected,
// making it safe to call on every new connection without performance overhead.
func GetBestPhysicalDefaultInterfaces() (PhysicalDefaultInterfaces, error) {
// Fast path: concurrent reads when cache is valid.
physicalDefaultIfacesLock.RLock()
if physicalDefaultIfacesCacheValid && !physicalDefaultIfacesChangedFlag.IsSet() {
result := physicalDefaultIfacesCache
physicalDefaultIfacesLock.RUnlock()
return result, nil
}
physicalDefaultIfacesLock.RUnlock()
// Slow path: refresh under write lock.
physicalDefaultIfacesLock.Lock()
defer physicalDefaultIfacesLock.Unlock()
// Re-check: another goroutine may have refreshed while we waited.
if physicalDefaultIfacesCacheValid && !physicalDefaultIfacesChangedFlag.IsSet() {
return physicalDefaultIfacesCache, nil
}
// Throttle: if a refresh just ran, return the cached result even if the
// change flag fired again — avoids hammering the OS during interface churn.
if physicalDefaultIfacesCacheValid && time.Now().Before(physicalDefaultIfacesDontRefreshUntil) {
return physicalDefaultIfacesCache, nil
}
physicalDefaultIfacesDontRefreshUntil = time.Now().Add(1 * time.Second)
// Consume the change flag before the (potentially slow) platform call so
// any change that arrives during the call will trigger a re-evaluation.
physicalDefaultIfacesChangedFlag.Refresh()
ipv4Iface, ipv6Iface, err := selectPhysicalDefaultInterfaces()
if err != nil {
physicalDefaultIfacesCacheValid = false
return PhysicalDefaultInterfaces{}, err
}
result := PhysicalDefaultInterfaces{
ForIPv4: interfaceToInfo(ipv4Iface),
ForIPv6: interfaceToInfo(ipv6Iface),
}
physicalDefaultIfacesCache = result
physicalDefaultIfacesCacheValid = true
return result, nil
}
// interfaceToInfo looks up iface in the interface cache and returns an
// InterfaceInfo populated with the first routable address per family.
// Falls back to scanning iface.Addrs() directly when the cache is unavailable
// (e.g. a transient failure during network churn), so IPv4/IPv6 are always
// populated when addresses exist on the interface.
// The caller must NOT hold ifaceCacheLock.
func interfaceToInfo(iface *net.Interface) *InterfaceInfo {
if iface == nil {
return nil
}
ifaceCacheLock.RLock()
if cacheReady() {
entry := searchIfaceByName(iface.Name)
ifaceCacheLock.RUnlock()
if entry != nil {
return buildInterfaceInfo(entry, nil)
}
// Interface not in cache yet (added after last refresh); fall through.
} else {
ifaceCacheLock.RUnlock()
ifaceCacheLock.Lock()
if err := ensureIfaceCache(); err == nil {
if entry := searchIfaceByName(iface.Name); entry != nil {
result := buildInterfaceInfo(entry, nil)
ifaceCacheLock.Unlock()
return result
}
}
ifaceCacheLock.Unlock()
}
// Cache unavailable or interface not present in it — populate addresses
// directly from the kernel so IPv4/IPv6 are never silently nil.
return buildInterfaceInfoDirect(iface)
}
// buildInterfaceInfoDirect constructs an InterfaceInfo by calling iface.Addrs()
// directly, without using the cache. Used as a fallback when the cache is
// unavailable. Uses the same isRoutableUnicastIP predicate as refreshIfaceCache.
func buildInterfaceInfoDirect(iface *net.Interface) *InterfaceInfo {
info := &InterfaceInfo{Interface: iface}
addrs, err := iface.Addrs()
if err != nil {
return info
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil {
continue
}
// Normalize to 4-byte form so isRoutableUnicastIP and family checks are consistent.
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
if !isRoutableUnicastIP(ip) {
continue
}
if ip.To4() != nil {
if info.IPv4 == nil {
info.IPv4 = ip
}
} else {
if info.IPv6 == nil {
info.IPv6 = ip
}
}
if info.IPv4 != nil && info.IPv6 != nil {
break
}
}
return info
}
// hasRoutableIPv4 reports whether iface has at least one unicast IPv4 address
// that is globally routable — not unspecified (0.0.0.0), loopback (127.x.x.x),
// or link-local/APIPA (169.254.x.x).
//
// An interface may be physically present and have a default route while still
// lacking a usable IP (DHCP not completed, cable just reconnected, etc.).
// Checking the address is the final confirmation that the interface can
// actually forward packets.
func hasRoutableIPv4(iface *net.Interface) bool {
addrs, err := iface.Addrs()
if err != nil {
return false
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
ip4 := ip.To4()
if ip4 == nil {
continue
}
if isRoutableUnicastIP(ip4) {
return true
}
}
return false
}
// hasRoutableIPv6 reports whether iface has at least one unicast IPv6 address
// that is globally routable — not unspecified (::), loopback (::1),
// or link-local (fe80::/10).
func hasRoutableIPv6(iface *net.Interface) bool {
addrs, err := iface.Addrs()
if err != nil {
return false
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
// Skip IPv4 addresses and nil.
if ip == nil || ip.To4() != nil {
continue
}
if isRoutableUnicastIP(ip) {
return true
}
}
return false
}
// errNoPhysicalDefaultInterface is returned by unsupported platform stubs.
var errNoPhysicalDefaultInterface = errors.New("physical network interface detection is not supported on this platform")
+10
View File
@@ -0,0 +1,10 @@
//go:build !linux && !windows
package netenv
import "net"
// selectPhysicalDefaultInterfaces is not implemented on this platform.
func selectPhysicalDefaultInterfaces() (*net.Interface, *net.Interface, error) {
return nil, nil, errNoPhysicalDefaultInterface
}
+157
View File
@@ -0,0 +1,157 @@
//go:build linux
package netenv
import (
"bufio"
"encoding/binary"
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
"strings"
)
// selectPhysicalDefaultInterfaces finds the best physical adapter per IP family
// that carries the default route, excluding all virtual and tunnel interfaces.
//
// Physical detection: the kernel creates /sys/class/net/<name>/device only for
// adapters bound to a real hardware driver. Virtual interfaces (tun, tap,
// bridge, veth, wireguard) never have this entry — this is the most reliable
// VPN-exclusion signal available without elevated privileges.
//
// IPv4 routes: /proc/net/route — always readable without root; provides
// destination, mask, and metric (decimal) for every IPv4 route.
//
// IPv6 routes: /proc/net/ipv6_route — same access requirements; provides
// destination, prefix length, next hop, and metric (hex) for every IPv6 route.
func selectPhysicalDefaultInterfaces() (*net.Interface, *net.Interface, error) {
type candidate struct {
name string
metric uint32
}
var v4candidates, v6candidates []candidate
// --- IPv4: read /proc/net/route ---
// Columns: Iface Dest Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
// Dest and Mask are 4-byte values in 8 hex chars, little-endian. Metric is decimal.
f4, err := os.Open("/proc/net/route")
if err != nil {
return nil, nil, fmt.Errorf("reading IPv4 routing table: %w", err)
}
defer f4.Close() //nolint:errcheck
scanner4 := bufio.NewScanner(f4)
scanner4.Scan() // skip header
for scanner4.Scan() {
fields := strings.Fields(scanner4.Text())
if len(fields) < 8 {
continue
}
dest, err := hex.DecodeString(fields[1])
if err != nil || len(dest) != 4 {
continue
}
mask, err := hex.DecodeString(fields[7])
if err != nil || len(mask) != 4 {
continue
}
// Default route: 0.0.0.0/0
if binary.LittleEndian.Uint32(dest) != 0 || binary.LittleEndian.Uint32(mask) != 0 {
continue
}
name := fields[0]
if !isSysfsPhysical(name) {
continue
}
// Metric column is decimal.
metric, err := strconv.ParseUint(fields[6], 10, 32)
if err != nil {
continue
}
v4candidates = append(v4candidates, candidate{name, uint32(metric)})
}
if err := scanner4.Err(); err != nil {
return nil, nil, fmt.Errorf("scanning IPv4 routing table: %w", err)
}
// --- IPv6: read /proc/net/ipv6_route ---
// Columns: dest destpfxlen src srcpfxlen nexthop metric refcnt use flags iface
// All addresses are 32 hex chars (no colons). Metric is hex. Iface is last.
f6, err := os.Open("/proc/net/ipv6_route")
if err == nil {
defer f6.Close() //nolint:errcheck
scanner6 := bufio.NewScanner(f6)
for scanner6.Scan() {
fields := strings.Fields(scanner6.Text())
if len(fields) < 10 {
continue
}
// Default route: destination = ::/0
if fields[0] != "00000000000000000000000000000000" {
continue
}
pfxLen, err := strconv.ParseUint(fields[1], 16, 8)
if err != nil || pfxLen != 0 {
continue
}
// Skip on-link entries that have no actual gateway.
if fields[4] == "00000000000000000000000000000000" {
continue
}
name := fields[len(fields)-1]
if !isSysfsPhysical(name) {
continue
}
// Metric is hex in ipv6_route (unlike decimal in /proc/net/route).
metric, err := strconv.ParseUint(fields[5], 16, 32)
if err != nil {
continue
}
v6candidates = append(v6candidates, candidate{name, uint32(metric)})
}
// IPv6 scanner errors are non-fatal — leave result.IPv6 as nil.
}
// If /proc/net/ipv6_route is absent, IPv6 is not configured; that is not an error.
// Pick the lowest-metric candidate per family that also has a routable address,
// confirming DHCP/SLAAC has completed and the interface is actively communicating.
var ipv4Iface, ipv6Iface *net.Interface
var bestV4Metric, bestV6Metric uint32
for _, c := range v4candidates {
iface, err := net.InterfaceByName(c.name)
if err != nil || !hasRoutableIPv4(iface) {
continue
}
if ipv4Iface == nil || c.metric < bestV4Metric {
ipv4Iface = iface
bestV4Metric = c.metric
}
}
for _, c := range v6candidates {
iface, err := net.InterfaceByName(c.name)
if err != nil || !hasRoutableIPv6(iface) {
continue
}
if ipv6Iface == nil || c.metric < bestV6Metric {
ipv6Iface = iface
bestV6Metric = c.metric
}
}
return ipv4Iface, ipv6Iface, nil
}
// isSysfsPhysical reports whether the named interface is backed by a real
// hardware driver. The kernel creates /sys/class/net/<name>/device only for
// adapters bound to an actual device driver (PCI/USB Ethernet, wireless card).
// Virtual interfaces — tun, tap, bridge, veth, wireguard, loopback — never
// have this sysfs entry.
func isSysfsPhysical(name string) bool {
_, err := os.Stat("/sys/class/net/" + name + "/device")
return err == nil
}
+471
View File
@@ -0,0 +1,471 @@
package netenv
import (
"net"
"testing"
"github.com/safing/portmaster/service/network/netutils"
)
// isRoutableIP returns true for IPs that the cache keeps: site-local or global.
// Matches the isRoutableUnicastIP predicate used in production code.
func isRoutableIP(ip net.IP) bool {
if ip == nil {
return false
}
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
scope := netutils.GetIPScope(ip)
return scope == netutils.SiteLocal || scope == netutils.Global
}
// getTestInterface picks the first network interface that matches the same
// criteria as the cache: FlagUp and at least one routable (non-link-local)
// unicast address. Falls back to loopback if no other candidate is found.
func getTestInterface(t *testing.T) net.Interface {
t.Helper()
ifaces, err := net.Interfaces()
if err != nil {
t.Fatalf("net.Interfaces() failed: %v", err)
}
for i := range ifaces {
iface := ifaces[i]
if iface.Flags&net.FlagUp == 0 {
continue
}
// Mirror the cache filter: loopback is excluded.
if iface.Flags&net.FlagLoopback != 0 {
continue
}
addrs, _ := iface.Addrs()
hasRoutable := false
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if isRoutableIP(ip) {
hasRoutable = true
break
}
}
if !hasRoutable {
continue
}
return iface
}
t.Skip("no usable non-loopback network interface found skipping test")
panic("unreachable")
}
// firstRoutableIP returns the first routable (non-link-local) unicast IP
// assigned to iface, or nil if none exists.
func firstRoutableIP(iface net.Interface) net.IP {
addrs, _ := iface.Addrs()
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if isRoutableIP(ip) {
return ip
}
}
return nil
}
// firstRoutableIPv4 returns the first routable IPv4 address on iface, or nil.
func firstRoutableIPv4(iface net.Interface) net.IP {
addrs, _ := iface.Addrs()
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if isRoutableIP(ip) {
if ip4 := ip.To4(); ip4 != nil {
return ip4
}
}
}
return nil
}
// firstRoutableIPv6 returns the first routable IPv6 address on iface, or nil.
func firstRoutableIPv6(iface net.Interface) net.IP {
addrs, _ := iface.Addrs()
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if isRoutableIP(ip) && ip.To4() == nil {
return ip
}
}
return nil
}
// ---- GetInterfaceByName -------------------------------------------------------
func TestGetInterfaceByName(t *testing.T) {
t.Parallel()
want := getTestInterface(t)
got, err := GetInterfaceByName(want.Name)
if err != nil {
t.Fatalf("GetInterfaceByName(%q): unexpected error: %v", want.Name, err)
}
if got.Interface.Name != want.Name {
t.Errorf("GetInterfaceByName(%q): got %q", want.Name, got.Interface.Name)
}
}
func TestGetInterfaceByName_NotFound(t *testing.T) {
t.Parallel()
_, err := GetInterfaceByName("__no_such_interface__")
if err == nil {
t.Fatal("expected error for unknown interface name, got nil")
}
}
// ---- GetInterfaceByIP --------------------------------------------------------
func TestGetInterfaceByIP(t *testing.T) {
t.Parallel()
iface := getTestInterface(t)
ip := firstRoutableIP(iface)
if ip == nil {
t.Skipf("interface %q has no routable address skipping", iface.Name)
}
got, err := GetInterfaceByIP(ip)
if err != nil {
t.Fatalf("GetInterfaceByIP(%s): unexpected error: %v", ip, err)
}
if got.Interface.Name != iface.Name {
t.Errorf("GetInterfaceByIP(%s): got interface %q, want %q", ip, got.Interface.Name, iface.Name)
}
}
func TestGetInterfaceByIP_NotFound(t *testing.T) {
t.Parallel()
// 192.0.2.0/24 is TEST-NET-1 (RFC 5737) never assigned on a real host.
ip := net.ParseIP("192.0.2.1")
_, err := GetInterfaceByIP(ip)
if err == nil {
t.Fatal("expected error for unassigned IP, got nil")
}
}
// ---- GetInterfaceByMAC -------------------------------------------------------
func TestGetInterfaceByMAC(t *testing.T) {
t.Parallel()
iface := getTestInterface(t)
if len(iface.HardwareAddr) == 0 {
t.Skipf("interface %q has no hardware address skipping", iface.Name)
}
got, err := GetInterfaceByMAC(iface.HardwareAddr)
if err != nil {
t.Fatalf("GetInterfaceByMAC(%s): unexpected error: %v", iface.HardwareAddr, err)
}
if got.Interface.Name != iface.Name {
t.Errorf("GetInterfaceByMAC(%s): got interface %q, want %q",
iface.HardwareAddr, got.Interface.Name, iface.Name)
}
}
// ---- GetInterface (multi-mode) -----------------------------------------------
func TestGetInterface_ByName(t *testing.T) {
t.Parallel()
want := getTestInterface(t)
got, err := GetInterface(want.Name)
if err != nil {
t.Fatalf("GetInterface(%q) by name: unexpected error: %v", want.Name, err)
}
if got.Interface.Name != want.Name {
t.Errorf("GetInterface(%q): got %q", want.Name, got.Interface.Name)
}
}
func TestGetInterface_ByIP(t *testing.T) {
t.Parallel()
iface := getTestInterface(t)
ip := firstRoutableIP(iface)
if ip == nil {
t.Skipf("interface %q has no routable address skipping", iface.Name)
}
ipStr := ip.String()
got, err := GetInterface(ipStr)
if err != nil {
t.Fatalf("GetInterface(%q) by IP: unexpected error: %v", ipStr, err)
}
if got.Interface.Name != iface.Name {
t.Errorf("GetInterface(%q): got %q, want %q", ipStr, got.Interface.Name, iface.Name)
}
}
func TestGetInterface_ByMAC(t *testing.T) {
t.Parallel()
iface := getTestInterface(t)
if len(iface.HardwareAddr) == 0 {
t.Skipf("interface %q has no hardware address skipping", iface.Name)
}
macStr := iface.HardwareAddr.String()
got, err := GetInterface(macStr)
if err != nil {
t.Fatalf("GetInterface(%q) by MAC: unexpected error: %v", macStr, err)
}
if got.Interface.Name != iface.Name {
t.Errorf("GetInterface(%q): got %q, want %q", macStr, got.Interface.Name, iface.Name)
}
}
func TestGetInterface_NotFound(t *testing.T) {
t.Parallel()
_, err := GetInterface("__no_such_interface__")
if err == nil {
t.Fatal("expected error for unrecognised ifinfo, got nil")
}
}
// TestGetInterfaceByIP_LinkLocalIPv6 verifies that IPv6 link-local addresses
// are filtered out of the cache and therefore never match a lookup.
func TestGetInterfaceByIP_LinkLocalIPv6(t *testing.T) {
t.Parallel()
ip := net.ParseIP("fe80::1")
_, err := GetInterfaceByIP(ip)
if err == nil {
t.Error("expected error for link-local IP fe80::1, got nil")
}
}
// TestGetInterfaceByIP_LinkLocalIPv4 verifies that IPv4 link-local addresses
// (APIPA range 169.254.x.x) are filtered out of the cache.
func TestGetInterfaceByIP_LinkLocalIPv4(t *testing.T) {
t.Parallel()
ip := net.ParseIP("169.254.0.1")
_, err := GetInterfaceByIP(ip)
if err == nil {
t.Error("expected error for link-local IP 169.254.0.1, got nil")
}
}
// TestGetInterface_RepeatedCall verifies that repeated calls with the same
// argument succeed consistently (exercises the list cache path).
func TestGetInterface_RepeatedCall(t *testing.T) {
t.Parallel()
want := getTestInterface(t)
got1, err := GetInterface(want.Name)
if err != nil {
t.Fatalf("first GetInterface(%q): %v", want.Name, err)
}
got2, err := GetInterface(want.Name)
if err != nil {
t.Fatalf("second GetInterface(%q): %v", want.Name, err)
}
if got1.Interface.Name != got2.Interface.Name {
t.Errorf("inconsistent results: got %q then %q", got1.Interface.Name, got2.Interface.Name)
}
}
// ---- InterfaceInfo bind-address fields ---------------------------------------
// TestGetInterfaceByIP_MatchedIPv4InInfo verifies that when an interface is
// found by an IPv4 address, that exact IP is returned in InterfaceInfo.IPv4.
func TestGetInterfaceByIP_MatchedIPv4InInfo(t *testing.T) {
t.Parallel()
iface := getTestInterface(t)
ip := firstRoutableIPv4(iface)
if ip == nil {
t.Skipf("interface %q has no routable IPv4 address skipping", iface.Name)
}
info, err := GetInterfaceByIP(ip)
if err != nil {
t.Fatalf("GetInterfaceByIP(%s): unexpected error: %v", ip, err)
}
if !info.IPv4.Equal(ip) {
t.Errorf("InterfaceInfo.IPv4: got %s, want %s", info.IPv4, ip)
}
}
// TestGetInterfaceByIP_MatchedIPv6InInfo verifies that when an interface is
// found by an IPv6 address, that exact IP is returned in InterfaceInfo.IPv6.
func TestGetInterfaceByIP_MatchedIPv6InInfo(t *testing.T) {
t.Parallel()
iface := getTestInterface(t)
ip := firstRoutableIPv6(iface)
if ip == nil {
t.Skipf("interface %q has no routable IPv6 address skipping", iface.Name)
}
info, err := GetInterfaceByIP(ip)
if err != nil {
t.Fatalf("GetInterfaceByIP(%s): unexpected error: %v", ip, err)
}
if !info.IPv6.Equal(ip) {
t.Errorf("InterfaceInfo.IPv6: got %s, want %s", info.IPv6, ip)
}
}
// TestGetInterfaceByName_IPv4InInfo verifies that when an interface is found
// by name, InterfaceInfo.IPv4 is populated with the first routable IPv4 address.
func TestGetInterfaceByName_IPv4InInfo(t *testing.T) {
t.Parallel()
iface := getTestInterface(t)
expectedIPv4 := firstRoutableIPv4(iface)
if expectedIPv4 == nil {
t.Skipf("interface %q has no routable IPv4 address skipping", iface.Name)
}
info, err := GetInterfaceByName(iface.Name)
if err != nil {
t.Fatalf("GetInterfaceByName(%q): unexpected error: %v", iface.Name, err)
}
if !info.IPv4.Equal(expectedIPv4) {
t.Errorf("InterfaceInfo.IPv4: got %s, want %s", info.IPv4, expectedIPv4)
}
}
// ---- Helper functions for logging -------------------------------------------------------
// logInterfaceInfo logs IPv4 and IPv6 interface info from PhysicalDefaultInterfaces.
func logInterfaceInfo(t *testing.T, label string, result PhysicalDefaultInterfaces) {
logIP := func(version string, info *InterfaceInfo) {
if info == nil {
t.Logf("%s - %s: <nil>", label, version)
return
}
var ip net.IP
if version == "IPv4" {
ip = info.IPv4
} else {
ip = info.IPv6
}
name := info.Interface.Name
if ip != nil {
t.Logf("%s - %s: %s (%s)", label, version, name, ip)
} else {
t.Logf("%s - %s: %s", label, version, name)
}
}
logIP("IPv4", result.ForIPv4)
logIP("IPv6", result.ForIPv6)
}
// ---- GetBestPhysicalDefaultInterfaces() -----------------------------------------------------
// TestGetBestPhysicalDefaultInterfaces verifies that GetBestPhysicalDefaultInterfaces
// returns at least one valid physical interface and that each non-nil result
// has a routable address for its respective family.
func TestGetBestPhysicalDefaultInterfaces(t *testing.T) {
t.Parallel()
result, err := GetBestPhysicalDefaultInterfaces()
if err != nil {
t.Fatalf("GetBestPhysicalDefaultInterfaces: unexpected error: %v", err)
}
// Print found interfaces
logInterfaceInfo(t, "Result", PhysicalDefaultInterfaces{ForIPv4: result.ForIPv4, ForIPv6: result.ForIPv6})
// At least one family must be resolved on any connected machine.
if result.ForIPv4 == nil && result.ForIPv6 == nil {
t.Fatal("GetBestPhysicalDefaultInterfaces: both ForIPv4 and ForIPv6 are nil; expected at least one")
}
if result.ForIPv4 != nil && !hasRoutableIPv4(result.ForIPv4.Interface) {
t.Errorf("GetBestPhysicalDefaultInterfaces: ForIPv4 interface %q has no routable IPv4 address", result.ForIPv4.Interface.Name)
}
if result.ForIPv6 != nil && !hasRoutableIPv6(result.ForIPv6.Interface) {
t.Errorf("GetBestPhysicalDefaultInterfaces: ForIPv6 interface %q has no routable IPv6 address", result.ForIPv6.Interface.Name)
}
}
// TestGetBestPhysicalDefaultInterfaces_Repeated verifies that repeated calls
// return consistent results (exercises the cache fast-path).
func TestGetBestPhysicalDefaultInterfaces_Repeated(t *testing.T) {
t.Parallel()
first, err := GetBestPhysicalDefaultInterfaces()
if err != nil {
t.Fatalf("first call: %v", err)
}
second, err := GetBestPhysicalDefaultInterfaces()
if err != nil {
t.Fatalf("second call: %v", err)
}
firstName4 := ifaceName(first.ForIPv4)
firstName6 := ifaceName(first.ForIPv6)
secondName4 := ifaceName(second.ForIPv4)
secondName6 := ifaceName(second.ForIPv6)
// Print found interfaces from both calls
logInterfaceInfo(t, "First call", first)
logInterfaceInfo(t, "Second call", second)
if firstName4 != secondName4 {
t.Errorf("IPv4: inconsistent results across calls: %q then %q", firstName4, secondName4)
}
if firstName6 != secondName6 {
t.Errorf("IPv6: inconsistent results across calls: %q then %q", firstName6, secondName6)
}
}
// ifaceName returns the interface name or "<nil>" for a nil InterfaceInfo.
// Used to produce readable test failure messages.
func ifaceName(info *InterfaceInfo) string {
if info == nil {
return "<nil>"
}
return info.Interface.Name
}
+133
View File
@@ -0,0 +1,133 @@
//go:build windows
package netenv
import (
"fmt"
"net"
"unsafe"
"golang.org/x/sys/windows"
)
// Windows IANA ifType constants.
// https://www.iana.org/assignments/ianaiftype-mib/ianaiftype-mib
//
// Only types that represent real physical hardware used for internet access
// are listed. Types that look physical but are excluded with justification:
// - IF_TYPE_GIGABITETHERNET (117): Windows drivers report GbE/10GbE as
// ETHERNET_CSMACD (6) at the NDIS level; 117 is never seen in practice.
// - IF_TYPE_PPP (23): shared by both dial-up modems and PPTP/PPPoE VPNs —
// too ambiguous to include safely.
// - IF_TYPE_USB (160): USB Ethernet dongles register as ETHERNET_CSMACD (6)
// after the NDIS miniport wraps them; the USB type is not exposed here.
const (
ifTypeEthernetCSMACD uint32 = 6 // 802.3 wired Ethernet (also used for GbE, 10GbE, USB dongles)
ifTypeIEEE80211 uint32 = 71 // 802.11 WiFi
ifTypeIEEE8023ADLag uint32 = 161 // 802.3ad link aggregation / NIC teaming
ifTypeIEEE80216WMAN uint32 = 237 // WiMAX fixed wireless
ifTypeWWANPP uint32 = 243 // mobile broadband — GSM/LTE/5G
ifTypeWWANPP2 uint32 = 244 // mobile broadband — CDMA
)
// selectPhysicalDefaultInterfaces calls GetAdaptersAddresses once with
// AF_UNSPEC to enumerate all adapters for both IP families in a single kernel
// call. The gateway list (FirstGatewayAddress) contains entries for all
// families; each entry's SocketAddress family field distinguishes IPv4 from
// IPv6. Both Ipv4Metric/IfIndex and Ipv6Metric/Ipv6IfIndex are populated in
// a single AF_UNSPEC response, so no second call is needed.
//
// Physical detection: Windows reports the adapter type via IfType. VPN and
// tunnel drivers always register as IF_TYPE_TUNNEL (131), IF_TYPE_PPP (23),
// IF_TYPE_OTHER (1), or similar non-physical types — never as Ethernet or
// WiFi — so this filter is reliable against any VPN software.
func selectPhysicalDefaultInterfaces() (*net.Interface, *net.Interface, error) {
adapters, err := getAdapterAddresses()
if err != nil {
return nil, nil, err
}
var ipv4Iface, ipv6Iface *net.Interface
var bestV4Metric, bestV6Metric uint32
for a := adapters; a != nil; a = a.Next {
if !isPhysicalIfType(a.IfType) {
continue
}
// Walk the gateway list once and record which families have a gateway.
hasV4Gateway, hasV6Gateway := false, false
for gw := a.FirstGatewayAddress; gw != nil; gw = gw.Next {
switch gw.Address.Sockaddr.Addr.Family {
case windows.AF_INET:
hasV4Gateway = true
case windows.AF_INET6:
hasV6Gateway = true
}
if hasV4Gateway && hasV6Gateway {
break
}
}
// IPv4 candidate: needs a gateway, a valid index, and a routable address.
if hasV4Gateway && a.IfIndex != 0 {
if iface, err := net.InterfaceByIndex(int(a.IfIndex)); err == nil && hasRoutableIPv4(iface) {
if ipv4Iface == nil || a.Ipv4Metric < bestV4Metric {
ipv4Iface = iface
bestV4Metric = a.Ipv4Metric
}
}
}
// IPv6 candidate: needs a gateway, a valid index, and a routable address.
if hasV6Gateway && a.Ipv6IfIndex != 0 {
if iface, err := net.InterfaceByIndex(int(a.Ipv6IfIndex)); err == nil && hasRoutableIPv6(iface) {
if ipv6Iface == nil || a.Ipv6Metric < bestV6Metric {
ipv6Iface = iface
bestV6Metric = a.Ipv6Metric
}
}
}
}
return ipv4Iface, ipv6Iface, nil
}
// isPhysicalIfType reports whether the Windows interface type corresponds to
// real hardware. VPN and tunnel adapters always use non-physical type values.
func isPhysicalIfType(ifType uint32) bool {
switch ifType {
case ifTypeEthernetCSMACD, ifTypeIEEE80211, ifTypeIEEE8023ADLag,
ifTypeIEEE80216WMAN, ifTypeWWANPP, ifTypeWWANPP2:
return true
}
return false
}
// getAdapterAddresses calls GetAdaptersAddresses with AF_UNSPEC and
// GAA_FLAG_INCLUDE_GATEWAYS, returning adapters for all address families in
// one kernel call. It retries with an enlarged buffer if the OS signals that
// the initial 15 KB estimate was too small.
func getAdapterAddresses() (*windows.IpAdapterAddresses, error) {
// 15 KB covers the vast majority of machines (typically < 2 KB per adapter).
size := uint32(15000)
for {
buf := make([]byte, size)
head := (*windows.IpAdapterAddresses)(unsafe.Pointer(&buf[0]))
err := windows.GetAdaptersAddresses(
windows.AF_UNSPEC,
windows.GAA_FLAG_INCLUDE_GATEWAYS,
0,
head,
&size,
)
if err == windows.ERROR_BUFFER_OVERFLOW {
// size has been updated to the required value; retry.
continue
}
if err != nil {
return nil, fmt.Errorf("GetAdaptersAddresses: %w", err)
}
return head, nil
}
}
+4
View File
@@ -226,6 +226,10 @@ func convertConnection(conn *network.Connection) (*Conn, error) {
c.ExitNode = &exitNode
}
if conn.SplitTunContext != nil {
extraData["split_tun"] = conn.SplitTunContext
}
if conn.DNSContext != nil {
extraData["dns"] = conn.DNSContext
}
+1
View File
@@ -167,6 +167,7 @@ func AddNetworkDebugData(di *debug.Info, profile, where string) {
switch conn.Verdict { //nolint:exhaustive
case VerdictAccept,
VerdictRerouteToNameserver,
VerdictRerouteToSplitTun,
VerdictRerouteToTunnel:
accepted++
+14 -1
View File
@@ -56,6 +56,15 @@ type ProcessContext struct {
Source string
}
// SplitTunContext holds additional information about the split tunnel
// that a connection is routed through.
type SplitTunContext struct {
// Interface is the name of the network interface the connection is bound to.
Interface string
// IP is the IP address used to bind the connection to the interface.
IP net.IP
}
// ConnectionType is a type of connection.
type ConnectionType int8
@@ -170,6 +179,10 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
GetExitNodeID() string
StopTunnel() error
}
// SplitTunContext holds additional information about the split tunnel
// that this connection is routed through. It is set when the connection
// verdict is VerdictRerouteToSplitTun and the interface has been resolved.
SplitTunContext *SplitTunContext
// HistoryEnabled is set to true when the connection should be persisted
// in the history database.
@@ -795,7 +808,7 @@ func (conn *Connection) Save() {
// nolint:exhaustive
switch conn.Verdict {
case VerdictAccept, VerdictRerouteToNameserver:
case VerdictAccept, VerdictRerouteToNameserver, VerdictRerouteToSplitTun:
conn.ConnectionEstablished = true
case VerdictRerouteToTunnel:
// this is already handled when the connection tunnel has been
+3 -7
View File
@@ -216,10 +216,10 @@ func (conn *Connection) ReplyWithDNS(ctx context.Context, request *dns.Msg) *dns
return nil // Do not respond to request.
case VerdictFailed:
return nsutil.BlockIP().ReplyWithDNS(ctx, request)
case VerdictUndecided, VerdictUndeterminable,
VerdictAccept, VerdictRerouteToNameserver, VerdictRerouteToTunnel:
fallthrough
default:
// ReplyWithDNS is called when a DNS response to a DNS message is
// crafted because the request is either denied or blocked.
// So, other verdicts are not expected here.
reply := nsutil.ServerFailure().ReplyWithDNS(ctx, request)
nsutil.AddMessagesToReply(ctx, reply, log.ErrorLevel, "INTERNAL ERROR: incorrect use of Connection DNS Responder")
return reply
@@ -233,10 +233,6 @@ func (conn *Connection) GetExtraRRs(ctx context.Context, request *dns.Msg) []dns
switch conn.Verdict {
case VerdictFailed:
level = log.ErrorLevel
case VerdictUndecided, VerdictUndeterminable,
VerdictAccept, VerdictBlock, VerdictDrop,
VerdictRerouteToNameserver, VerdictRerouteToTunnel:
fallthrough
default:
level = log.InfoLevel
}
+1 -1
View File
@@ -145,7 +145,7 @@ func (conn *Connection) addToMetrics() {
blockedOutConnCounter.Inc()
conn.addedToMetrics = true
return
case VerdictAccept, VerdictRerouteToTunnel:
case VerdictAccept, VerdictRerouteToTunnel, VerdictRerouteToSplitTun:
// Continue to next section.
default:
// Connection is not counted.
+4
View File
@@ -74,4 +74,8 @@ func (pkt *InfoPacket) RerouteToTunnel() error {
return ErrInfoOnlyPacket
}
func (pkt *InfoPacket) RerouteToSplitTun() error {
return ErrInfoOnlyPacket
}
var _ Packet = &InfoPacket{}
+1
View File
@@ -231,6 +231,7 @@ type Packet interface {
PermanentDrop() error
RerouteToNameserver() error
RerouteToTunnel() error
RerouteToSplitTun() error
FastTrackedByIntegration() bool
InfoOnly() bool
ExpectInfo() bool
+5
View File
@@ -15,6 +15,7 @@ const (
VerdictRerouteToNameserver Verdict = 5
VerdictRerouteToTunnel Verdict = 6
VerdictFailed Verdict = 7
VerdictRerouteToSplitTun Verdict = 8
)
func (v Verdict) String() string {
@@ -33,6 +34,8 @@ func (v Verdict) String() string {
return "RerouteToNameserver"
case VerdictRerouteToTunnel:
return "RerouteToTunnel"
case VerdictRerouteToSplitTun:
return "RerouteToSplitTun"
case VerdictFailed:
return "Failed"
default:
@@ -57,6 +60,8 @@ func (v Verdict) Verb() string {
return "redirected to nameserver"
case VerdictRerouteToTunnel:
return "tunneled"
case VerdictRerouteToSplitTun:
return "split tunneled"
case VerdictFailed:
return "failed"
default:
+8
View File
@@ -17,6 +17,7 @@ var (
cfgDefaultAction uint8
cfgEndpoints endpoints.Endpoints
cfgServiceEndpoints endpoints.Endpoints
cfgSplitTunUsagePolicy endpoints.Endpoints
cfgSPNUsagePolicy endpoints.Endpoints
cfgSPNTransitHubPolicy endpoints.Endpoints
cfgSPNExitHubPolicy endpoints.Endpoints
@@ -74,6 +75,13 @@ func updateGlobalConfigProfile(_ context.Context) error {
lastErr = err
}
list = cfgOptionSplitTunUsagePolicy()
cfgSplitTunUsagePolicy, err = endpoints.ParseEndpoints(list)
if err != nil {
// TODO: module error?
lastErr = err
}
list = cfgOptionSPNUsagePolicy()
cfgSPNUsagePolicy, err = endpoints.ParseEndpoints(list)
if err != nil {
+110 -1
View File
@@ -1,6 +1,7 @@
package profile
import (
"errors"
"strings"
"github.com/safing/portmaster/base/config"
@@ -141,6 +142,19 @@ var (
cfgOptionExitHubPolicyOrder = 147
// Setting "DNS Exit Node Rules" at order 148.
// Split Tunnel.
CfgOptionSplitTunUseKey = "splittun/use"
cfgOptionSplitTunUse config.BoolOption
cfgOptionSplitTunUseOrder = 212
CfgOptionSplitTunInterfaceKey = "splittun/networkInterface"
cfgOptionSplitTunInterface config.StringOption
cfgOptionSplitTunInterfaceOrder = 214
CfgOptionSplitTunUsagePolicyKey = "splittun/usagePolicy"
cfgOptionSplitTunUsagePolicy config.StringArrayOption
cfgOptionSplitTunUsagePolicyOrder = 216
)
var (
@@ -698,7 +712,7 @@ Please note that DNS bypass attempts might be additionally blocked in the System
err = config.Register(&config.Option{
Name: "SPN Rules",
Key: CfgOptionSPNUsagePolicyKey,
Description: `Customize which websites should or should not be routed through the SPN. Only active if "Use SPN" is enabled.`,
Description: `Customize rules which connections should or should not be routed through the SPN. Only active if "Use SPN" is enabled.`,
Help: rulesHelp,
Sensitive: true,
OptType: config.OptTypeStringArray,
@@ -819,5 +833,100 @@ By default, the Portmaster tries to choose the node closest to the destination a
cfgOptionRoutingAlgorithm = config.Concurrent.GetAsString(CfgOptionRoutingAlgorithmKey, DefaultRoutingProfileID)
cfgStringOptions[CfgOptionRoutingAlgorithmKey] = cfgOptionRoutingAlgorithm
//
// Split Tunnel
//
// Split Tunnel: Use
err = config.Register(&config.Option{
Name: "Use Split Tunnel",
Key: CfgOptionSplitTunUseKey,
Description: `Route specific traffic through a different network interface, bypassing default system routing (useful for avoiding VPNs for certain apps).
When you enable this and the Network Interface option is empty, Portmaster will try to route your traffic through the default physical network interface.
Important: SPN takes precedence over Split Tunnel. To use Split Tunnel with SPN, configure SPN on a per-app basis or define exceptions that allow Split Tunnel to take effect.`,
OptType: config.OptTypeBool,
DefaultValue: false,
Annotations: config.Annotations{
config.SettablePerAppAnnotation: true,
config.DisplayOrderAnnotation: cfgOptionSplitTunUseOrder,
config.CategoryAnnotation: "General",
},
})
if err != nil {
return err
}
cfgOptionSplitTunUse = config.Concurrent.GetAsBool(CfgOptionSplitTunUseKey, false)
cfgBoolOptions[CfgOptionSplitTunUseKey] = cfgOptionSplitTunUse
// Split Tunnel: Network Interface
err = config.Register(&config.Option{
Name: "Network Interface",
Key: CfgOptionSplitTunInterfaceKey,
Description: `Specify the network interface to route Split Tunnel traffic through. You can define it by:
- Interface name: "Ethernet", "Wi-Fi", "wlan0", etc.
- Interface IP address: "192.168.1.1", "10.0.0.1", etc.
- Interface MAC address: "00:1A:2B:3C:4D:5E", "01:23:45:67:89:AB", etc.
Leave empty to let Portmaster detect the physical network interface and ignore virtual VPN interfaces. This helps bypass VPN tunnels. For better reliability, you can specify the interface manually if empty value does not work as expected.
Important: The connection will be dropped if the network interface cannot be detected or becomes unavailable.
Important: SPN takes precedence over Split Tunnel. To use Split Tunnel with SPN, configure SPN on a per-app basis or define exceptions that allow Split Tunnel to take effect.`,
Sensitive: true,
OptType: config.OptTypeString,
DefaultValue: "",
Annotations: config.Annotations{
config.SettablePerAppAnnotation: true,
config.DisplayOrderAnnotation: cfgOptionSplitTunInterfaceOrder,
config.CategoryAnnotation: "General",
},
ValidationFunc: func(value interface{}) error {
if s, ok := value.(string); ok && s != "" && strings.TrimSpace(s) == "" {
return errors.New("network interface cannot contain only whitespace characters")
}
return nil
},
})
if err != nil {
return err
}
cfgOptionSplitTunInterface = config.Concurrent.GetAsString(CfgOptionSplitTunInterfaceKey, "")
cfgStringOptions[CfgOptionSplitTunInterfaceKey] = cfgOptionSplitTunInterface
// Split Tunnel: Rules
splitTunRulesVerdictNames := map[string]string{
"-": "Exclude", // Default.
"+": "Allow",
}
err = config.Register(&config.Option{
Name: "Split Tunnel Rules",
Key: CfgOptionSplitTunUsagePolicyKey,
Description: `Customize rules which connections should or should not be routed through the Split Tunnel. Only active if "Use Split Tunnel" is enabled.
Important: SPN takes precedence over Split Tunnel. To use Split Tunnel with SPN, configure SPN on a per-app basis or define exceptions that allow Split Tunnel to take effect.`,
Help: rulesHelp,
Sensitive: true,
OptType: config.OptTypeStringArray,
DefaultValue: []string{},
Annotations: config.Annotations{
config.SettablePerAppAnnotation: true,
config.StackableAnnotation: true,
config.CategoryAnnotation: "General",
config.DisplayOrderAnnotation: cfgOptionSplitTunUsagePolicyOrder,
config.DisplayHintAnnotation: endpoints.DisplayHintEndpointList,
endpoints.EndpointListVerdictNamesAnnotation: splitTunRulesVerdictNames,
},
ValidationRegex: endpoints.ListEntryValidationRegex,
ValidationFunc: endpoints.ValidateEndpointListConfigOption,
})
if err != nil {
return err
}
cfgOptionSplitTunUsagePolicy = config.Concurrent.GetAsStringArray(CfgOptionSplitTunUsagePolicyKey, []string{})
cfgStringArrayOptions[CfgOptionSplitTunUsagePolicyKey] = cfgOptionSplitTunUsagePolicy
return nil
}
+26
View File
@@ -50,6 +50,8 @@ type LayeredProfile struct {
SPNRoutingAlgorithm config.StringOption `json:"-"`
EnableHistory config.BoolOption `json:"-"`
KeepHistory config.IntOption `json:"-"`
UseSplitTun config.BoolOption `json:"-"`
SplitTunInterface config.StringOption `json:"-"`
}
// NewLayeredProfile returns a new layered profile based on the given local profile.
@@ -113,6 +115,14 @@ func NewLayeredProfile(localProfile *Profile) *LayeredProfile {
CfgOptionDomainHeuristicsKey,
cfgOptionDomainHeuristics,
)
lp.UseSplitTun = lp.wrapBoolOption(
CfgOptionSplitTunUseKey,
cfgOptionSplitTunUse,
)
lp.SplitTunInterface = lp.wrapStringOption(
CfgOptionSplitTunInterfaceKey,
cfgOptionSplitTunInterface,
)
lp.UseSPN = lp.wrapBoolOption(
CfgOptionUseSPNKey,
cfgOptionUseSPN,
@@ -349,6 +359,22 @@ func (lp *LayeredProfile) MatchServiceEndpoint(ctx context.Context, entity *inte
return cfgServiceEndpoints.Match(ctx, entity)
}
// MatchSplitTunUsagePolicy checks if the given endpoint matches an entry in any Split Tunnel usage policy in any of the profiles. This functions requires the layered profile to be read locked.
func (lp *LayeredProfile) MatchSplitTunUsagePolicy(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
for _, layer := range lp.layers {
if layer.splitTunUsagePolicy.IsSet() {
result, reason := layer.splitTunUsagePolicy.Match(ctx, entity)
if endpoints.IsDecision(result) {
return result, reason
}
}
}
cfgLock.RLock()
defer cfgLock.RUnlock()
return cfgSplitTunUsagePolicy.Match(ctx, entity)
}
// MatchSPNUsagePolicy checks if the given endpoint matches an entry in any of the profiles. This functions requires the layered profile to be read locked.
func (lp *LayeredProfile) MatchSPNUsagePolicy(ctx context.Context, entity *intel.Entity) (endpoints.EPResult, endpoints.Reason) {
for _, layer := range lp.layers {
+10
View File
@@ -124,6 +124,7 @@ type Profile struct { //nolint:maligned // not worth the effort
spnUsagePolicy endpoints.Endpoints
spnTransitHubPolicy endpoints.Endpoints
spnExitHubPolicy endpoints.Endpoints
splitTunUsagePolicy endpoints.Endpoints
// Lifecycle Management
outdated *abool.AtomicBool
@@ -203,6 +204,15 @@ func (profile *Profile) parseConfig() error {
}
}
list, ok = profile.configPerspective.GetAsStringArray(CfgOptionSplitTunUsagePolicyKey)
profile.splitTunUsagePolicy = nil
if ok {
profile.splitTunUsagePolicy, err = endpoints.ParseEndpoints(list)
if err != nil {
lastErr = err
}
}
list, ok = profile.configPerspective.GetAsStringArray(CfgOptionSPNUsagePolicyKey)
profile.spnUsagePolicy = nil
if ok {
+32
View File
@@ -0,0 +1,32 @@
package splittun
import (
"github.com/safing/portmaster/base/config"
)
var (
CfgOptionSplitTunEnableKey = "splittun/enable"
cfgOptionSplitTunEnable config.BoolOption
cfgOptionSplitTunEnableOrder = 210
)
func prepConfig() error {
// Register split tunnel module setting.
err := config.Register(&config.Option{
Name: "Split Tunnel Module",
Key: CfgOptionSplitTunEnableKey,
Description: "Start the Split Tunnel module. If turned off, the Split Tunnel functionality is fully disabled on this device.",
OptType: config.OptTypeBool,
DefaultValue: false,
Annotations: config.Annotations{
config.DisplayOrderAnnotation: cfgOptionSplitTunEnableOrder,
config.CategoryAnnotation: "General",
},
})
if err != nil {
return err
}
cfgOptionSplitTunEnable = config.Concurrent.GetAsBool(CfgOptionSplitTunEnableKey, false)
return nil
}
+107
View File
@@ -0,0 +1,107 @@
package splittun
import (
"errors"
"sync/atomic"
"github.com/safing/portmaster/base/config"
"github.com/safing/portmaster/service/mgr"
)
const SplitTunPort = 719
type SplitTunModule struct {
mgr *mgr.Manager
instance instance
}
var (
module *SplitTunModule
shimLoaded atomic.Bool
ready atomic.Bool // ready indicates whether the module is fully initialized and ready to handle requests.
)
func IsReady() bool {
return ready.Load()
}
func New(instance instance) (*SplitTunModule, error) {
if !shimLoaded.CompareAndSwap(false, true) {
return nil, errors.New("only one instance allowed")
}
m := mgr.New("SplitTunModule")
module = &SplitTunModule{
mgr: m,
instance: instance,
}
if err := prep(); err != nil {
return nil, err
}
return module, nil
}
func prep() error {
return prepConfig()
}
func (s *SplitTunModule) Manager() *mgr.Manager {
return s.mgr
}
func (s *SplitTunModule) Start() error {
module.instance.Config().EventConfigChange.AddCallback("splittun enable check", func(wc *mgr.WorkerCtx, t struct{}) (bool, error) {
if cfgOptionSplitTunEnable() {
s.enable()
} else {
s.disable()
}
return false, nil
})
if cfgOptionSplitTunEnable() {
s.enable()
}
return nil
}
func (s *SplitTunModule) Stop() error {
return s.disable()
}
func (s *SplitTunModule) enable() error {
if !ready.CompareAndSwap(false, true) {
return nil // already enabled
}
s.mgr.Info("splittun: enabling Split Tunnel functionality")
err := startProxies(s.mgr)
if err != nil {
s.mgr.Error("splittun: failed to start Split Tunnel proxies: ", err)
ready.Store(false)
}
return err
}
func (s *SplitTunModule) disable() error {
if !ready.CompareAndSwap(true, false) {
return nil // already disabled
}
s.mgr.Info("splittun: disabling Split Tunnel functionality")
clearPendingRequests()
err := stopProxies()
if err != nil {
s.mgr.Error("splittun: failed to stop Split Tunnel proxies: ", err)
}
return err
}
// INSTANCE
type instance interface {
Config() *config.Config
}
+164
View File
@@ -0,0 +1,164 @@
package splittun
import (
"context"
"fmt"
"net"
"sync"
"github.com/safing/portmaster/service/mgr"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/splittun/proxy"
)
var (
proxiesLocker sync.RWMutex
manager *mgr.Manager
tcp4Proxy *proxy.TCPProxy
tcp6Proxy *proxy.TCPProxy
udp4Proxy *proxy.UDPProxy
udp6Proxy *proxy.UDPProxy
)
type proxiedEgressFinder interface {
HasProxiedEgressConnection(destIP net.IP, destPort uint16) bool
}
func IsProxiedConnectionInfo(connInfo *network.Connection) bool {
if connInfo == nil || connInfo.Entity == nil || connInfo.LocalIP == nil || connInfo.Entity.IP == nil {
return false
}
proxiesLocker.RLock()
var finder proxiedEgressFinder
switch connInfo.IPProtocol {
case packet.TCP:
switch connInfo.IPVersion {
case packet.IPv4:
finder = tcp4Proxy
case packet.IPv6:
finder = tcp6Proxy
}
case packet.UDP:
switch connInfo.IPVersion {
case packet.IPv4:
finder = udp4Proxy
case packet.IPv6:
finder = udp6Proxy
}
}
if finder == nil {
proxiesLocker.RUnlock()
return false
}
isProxied := finder.HasProxiedEgressConnection(connInfo.Entity.IP, connInfo.Entity.Port)
proxiesLocker.RUnlock()
return isProxied
}
func startProxies(mgr *mgr.Manager) error {
var (
tcp4 *proxy.TCPProxy
tcp6 *proxy.TCPProxy
udp4 *proxy.UDPProxy
udp6 *proxy.UDPProxy
err error
)
_ = stopProxies()
// Ensure any partially-started proxies are shut down if we return an error.
var startErr error
defer func() {
if startErr != nil {
ctx := mgr.Ctx()
if tcp4 != nil {
tcp4.Shutdown(ctx)
}
if udp4 != nil {
udp4.Shutdown(ctx)
}
if tcp6 != nil {
tcp6.Shutdown(ctx)
}
if udp6 != nil {
udp6.Shutdown(ctx)
}
}
}()
tcp4, err = proxy.NewTCPProxy(fmt.Sprintf("0.0.0.0:%d", SplitTunPort), "tcp4", proxyDecider, mgr, "TCP-IPv4-proxy")
if err != nil {
startErr = fmt.Errorf("failed to start TCPv4 proxy: %w", err)
return startErr
}
udp4, err = proxy.NewUDPProxy(fmt.Sprintf("0.0.0.0:%d", SplitTunPort), "udp4", proxyDecider, mgr, "UDP-IPv4-proxy")
if err != nil {
startErr = fmt.Errorf("failed to start UDPv4 proxy: %w", err)
return startErr
}
if netenv.IPv6Enabled() {
tcp6, err = proxy.NewTCPProxy(fmt.Sprintf("[::]:%d", SplitTunPort), "tcp6", proxyDecider, mgr, "TCP-IPv6-proxy")
if err != nil {
startErr = fmt.Errorf("failed to start TCPv6 proxy: %w", err)
return startErr
}
udp6, err = proxy.NewUDPProxy(fmt.Sprintf("[::]:%d", SplitTunPort), "udp6", proxyDecider, mgr, "UDP-IPv6-proxy")
if err != nil {
startErr = fmt.Errorf("failed to start UDPv6 proxy: %w", err)
return startErr
}
}
proxiesLocker.Lock()
manager = mgr
tcp4Proxy = tcp4
tcp6Proxy = tcp6
udp4Proxy = udp4
udp6Proxy = udp6
proxiesLocker.Unlock()
return nil
}
func stopProxies() error {
proxiesLocker.Lock()
mgr := manager
tcp4 := tcp4Proxy
tcp6 := tcp6Proxy
udp4 := udp4Proxy
udp6 := udp6Proxy
tcp4Proxy = nil
tcp6Proxy = nil
udp4Proxy = nil
udp6Proxy = nil
proxiesLocker.Unlock()
var ctx context.Context
if mgr != nil {
ctx = mgr.Ctx()
} else {
ctx = context.Background()
}
if tcp4 != nil {
tcp4.Shutdown(ctx)
}
if tcp6 != nil {
tcp6.Shutdown(ctx)
}
if udp4 != nil {
udp4.Shutdown(ctx)
}
if udp6 != nil {
udp6.Shutdown(ctx)
}
return nil
}
+289
View File
@@ -0,0 +1,289 @@
# proxy
Internal Layer-4 TCP and UDP proxy package used by the split-tunnelling
subsystem. Provides injected routing decisions, session tracking, and graceful
shutdown.
---
## Overview
| Feature | TCP | UDP |
|---------|-----|-----|
| Routing via `DeciderFunc` | ✓ | ✓ |
| Optional source-address binding | ✓ | ✓ |
| Interface binding via `SO_BINDTODEVICE` (Linux) | ✓ | ✓ |
| Session tracking + metrics | ✓ | ✓ |
| Pooled copy buffers | ✓ | ✓ |
| Graceful shutdown | ✓ | ✓ |
| Max sessions limit | ✓ | ✓ |
| Read/write deadlines | ✓ | ✓ |
| Idle eviction (cleanup loop) | — | ✓ |
| Bidirectional, half-close | ✓ | n/a |
---
## API
### Types
```go
// LocalBinding carries the local-side binding parameters for an outbound proxy
// connection. Both fields are optional and may be set independently.
type LocalBinding struct {
// IP is the local source address to bind the outgoing socket to.
// If nil, the OS selects an appropriate source address.
IP net.IP
// Interface is the name of the network interface (e.g. "eth0") to bind
// the outgoing socket to via SO_BINDTODEVICE (Linux only).
// An empty string disables interface-level binding.
Interface string
}
// DeciderFunc is called once per new session to determine the upstream
// destination and optional local binding parameters for the outgoing socket.
//
// local is the proxy's listen address; peer is the connecting client's address.
//
// It returns:
// - remoteIP: required upstream IP address.
// - remotePort: required upstream port.
// - binding: optional local binding; nil lets the OS choose freely.
// Set binding.IP to pin the source address, binding.Interface
// to restrict the socket to a specific network device (Linux).
// - extraInfo: optional caller-defined value attached to the session's ConnContext.
// - err: non-nil rejects the session without dialling upstream.
type DeciderFunc func(local net.Addr, peer net.Addr) (
remoteIP net.IP,
remotePort uint16,
binding *LocalBinding,
extraInfo any,
err error,
)
// Logger is the minimal interface accepted by both proxies.
// Pass nil to suppress all log output.
type Logger interface {
Debug(msg string, args ...any)
Info(msg string, args ...any)
Warn(msg string, args ...any)
Error(msg string, args ...any)
}
// ConnContext holds observable state for one proxy session.
// Counters are updated atomically and safe for concurrent reads.
type ConnContext struct {
BytesIn atomic.Uint64 // bytes forwarded upstream → client
BytesOut atomic.Uint64 // bytes forwarded client → upstream
PacketsIn atomic.Uint64 // UDP datagrams upstream → client
PacketsOut atomic.Uint64 // UDP datagrams client → upstream
}
func (c *ConnContext) ID() uint64
func (c *ConnContext) PeerAddr() net.Addr
func (c *ConnContext) DestIP() net.IP
func (c *ConnContext) DestPort() uint16
func (c *ConnContext) CreatedAt() time.Time
func (c *ConnContext) LastSeen() time.Time
func (c *ConnContext) ExtraInfo() any
func (c *ConnContext) Close() // cancels the session
```
### Constructors
```go
// TCP — uses DefaultConfig
func NewTCPProxy(listenAddr string, network string, decider DeciderFunc, logger Logger, logPrefix string) (*TCPProxy, error)
// TCP — custom configuration
func NewTCPProxyWithConfig(listenAddr string, network string, decider DeciderFunc, logger Logger, cfg Config, logPrefix string) (*TCPProxy, error)
// UDP — uses DefaultConfig
func NewUDPProxy(listenAddr string, network string, decider DeciderFunc, logger Logger, logPrefix string) (*UDPProxy, error)
// UDP — custom configuration
func NewUDPProxyWithConfig(listenAddr string, network string, decider DeciderFunc, logger Logger, cfg Config, logPrefix string) (*UDPProxy, error)
```
Both constructors bind the socket and start background goroutines immediately.
They return an error if binding fails or if `decider` is nil.
### Address
```go
func (p *TCPProxy) Addr() net.Addr
func (p *UDPProxy) Addr() net.Addr
```
Returns the address the proxy is currently listening on.
### Configuration
```go
type Config struct {
// MaxSessions is the maximum number of concurrent sessions (0 = unlimited).
// Default: 2048.
MaxSessions int
// ReadTimeout closes a session after this duration with no bytes received
// from the source. The deadline is rolled forward on every successful
// read, so only truly silent sessions are evicted.
// Default: 5 min.
ReadTimeout time.Duration
// WriteTimeout is the maximum time allowed for a single write to complete.
// Guards against a stalled destination holding a goroutine open.
// Default: 30 s.
WriteTimeout time.Duration
// BufferSize is the size of copy buffers used by the TCP pipe (bytes).
// UDP always uses 64 KiB buffers regardless of this setting.
// Default: 32 KiB.
BufferSize int
// DialTimeout is the maximum time the TCP proxy waits when dialling the
// upstream destination. Default: 10 s.
DialTimeout time.Duration
}
func DefaultConfig() Config
```
### Shutdown
```go
func (p *TCPProxy) Shutdown(ctx context.Context) error
func (p *UDPProxy) Shutdown(ctx context.Context) error
```
Closes the listen socket, cancels all active sessions, and waits for
goroutines to drain. If `ctx` expires first, the method returns
`ctx.Err()` but goroutines are still cleaning up (they are not leaked).
### Session lookup
```go
// Returns all active sessions whose upstream destination matches destIP:destPort.
// Returns nil if none exist.
func (p *TCPProxy) FindProxiedEgressConnection(destIP net.IP, destPort uint16) []*ConnContext
func (p *UDPProxy) FindProxiedEgressConnection(destIP net.IP, destPort uint16) []*ConnContext
```
### Metrics
```go
type Metrics struct {
ActiveSessions uint64
TotalCreated uint64
TotalClosed uint64
}
func (p *TCPProxy) Metrics() Metrics
func (p *UDPProxy) Metrics() Metrics
```
---
## Usage examples
### Transparent TCP proxy (always route to a fixed backend)
```go
decider := func(local, peer net.Addr) (net.IP, uint16, *proxy.LocalBinding, any, error) {
return net.ParseIP("192.168.1.10"), 8080, nil, nil, nil
}
p, err := proxy.NewTCPProxy(":8080", "tcp4", decider, nil, "tcp proxy IPv4")
if err != nil {
log.Fatal(err)
}
// Graceful shutdown on SIGTERM with a 10-second drain window.
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM)
<-sig
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
p.Shutdown(ctx)
```
### Per-client routing with source-address and interface binding (split tunnelling)
```go
decider := func(local, peer net.Addr) (net.IP, uint16, *proxy.LocalBinding, any, error) {
host, _, _ := net.SplitHostPort(peer.String())
if isTunnelledIP(host) {
// Route through the physical interface, binding the source address and
// restricting the socket to that device so traffic bypasses the VPN.
return directGatewayIP, 443, &proxy.LocalBinding{
IP: net.ParseIP("192.168.1.5"), // physical interface address
Interface: "eth0", // Linux: SO_BINDTODEVICE
}, nil, nil
}
return vpnGatewayIP, 443, nil, nil, nil
}
p, err := proxy.NewTCPProxy(":443", "tcp4", decider, myLogger, "tcp proxy IPv4")
```
### UDP proxy with custom timeouts
```go
cfg := proxy.DefaultConfig()
cfg.ReadTimeout = 30 * time.Second
cfg.MaxSessions = 1024
p, err := proxy.NewUDPProxyWithConfig(":5353", "udp4", decider, myLogger, cfg, "udp proxy IPv4")
```
---
## Running tests and benchmarks
```sh
# Unit tests
go test ./...
# Race detector
go test -race ./...
# Benchmarks with allocation reporting
go test -run='^$' -bench=Benchmark -benchmem # All benchmarks
go test -run='^$' -bench=BenchmarkTCP -benchmem # TCP only
go test -run='^$' -bench=BenchmarkUDP -benchmem # UDP only
```
---
## Design notes
* **Pooled buffers** — TCP pipes use a `sync.Pool` of 32 KiB `[]byte` slices;
the UDP path uses a separate pool of 64 KiB slices (maximum UDP payload).
Both avoid per-transfer heap allocations in steady state.
* **Goroutine budget** — the TCP proxy spawns four goroutines per session: one
session handler, two bidirectional copy goroutines (one per direction), and
one watchdog; the UDP proxy spawns one goroutine per session (upstream
reader) plus two shared goroutines (inbound read loop and idle cleanup loop).
All goroutines are tracked via a `sync.WaitGroup`.
* **Half-close** — when one TCP peer closes its write side, the proxy attempts
`CloseWrite` on the upstream, enabling proper FIN propagation.
* **NAT session table** — UDP sessions are keyed by the client's `"ip:port"`
string. A double-checked locking pattern prevents duplicate sessions under
burst traffic.
* **UDP write deadline sharing** — all upstream-to-client goroutines write on
the same shared listen socket. Each goroutine sets a rolling write deadline
immediately before its own write, so concurrent sessions can shift each
other's deadline by at most `WriteTimeout`. This is an accepted trade-off of
the single-socket UDP design.
* **Context propagation** — the proxy's top-level `context.Context` is the
parent of every session context, so a single `Shutdown` call cascades to
all live sessions.
* **Interface binding (Linux)** — when `LocalBinding.Interface` is non-empty,
`SO_BINDTODEVICE` is set on the outgoing socket via `net.Dialer.Control`
before `connect(2)`. This forces the kernel to route the connection through
the named device regardless of the routing table, which is required for
split-tunnelling when a default VPN route would otherwise capture the traffic.
On non-Linux platforms the field is ignored (no-op).
+156
View File
@@ -0,0 +1,156 @@
package proxy
import (
"bytes"
"context"
"io"
"net"
"testing"
"time"
)
// ─── TCP benchmarks ───────────────────────────────────────────────────────────
// BenchmarkTCPProxy_Throughput measures round-trip throughput through the TCP
// proxy using a local echo server. Run with -benchmem to observe allocations.
func BenchmarkTCPProxy_Throughput(b *testing.B) {
echoAddr, stopEcho := startTCPEchoServer(b)
defer stopEcho()
proxy, err := NewTCPProxy("127.0.0.1:0", "tcp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
b.Fatalf("NewTCPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
conn, err := net.DialTimeout("tcp", proxy.Addr().String(), 2*time.Second)
if err != nil {
b.Fatalf("dial: %v", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Minute)) //nolint:errcheck
const msgSize = 32 * 1024
payload := bytes.Repeat([]byte("B"), msgSize)
recv := make([]byte, msgSize)
b.SetBytes(int64(msgSize))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := conn.Write(payload); err != nil {
b.Fatalf("write: %v", err)
}
if _, err := io.ReadFull(conn, recv); err != nil {
b.Fatalf("read: %v", err)
}
}
}
// BenchmarkTCPProxy_NewSession measures the overhead of establishing and
// tearing down a new TCP session through the proxy.
func BenchmarkTCPProxy_NewSession(b *testing.B) {
echoAddr, stopEcho := startTCPEchoServer(b)
defer stopEcho()
proxy, err := NewTCPProxy("127.0.0.1:0", "tcp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
b.Fatalf("NewTCPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := net.DialTimeout("tcp", proxy.Addr().String(), 2*time.Second)
if err != nil {
b.Fatalf("dial: %v", err)
}
conn.Close()
// Allow the session to be removed before next iteration.
for proxy.Metrics().TotalClosed < uint64(i+1) {
// spin — in benchmarks this is acceptable over a sleep
}
}
}
// ─── UDP benchmarks ───────────────────────────────────────────────────────────
// BenchmarkUDPProxy_Throughput measures datagrams-per-second through the UDP
// proxy.
func BenchmarkUDPProxy_Throughput(b *testing.B) {
echoAddr, stopEcho := startUDPEchoServer(b)
defer stopEcho()
cfg := DefaultConfig()
cfg.ReadTimeout = 30 * time.Second
proxy, err := NewUDPProxyWithConfig("127.0.0.1:0", "udp4", passThroughDecider(echoAddr), nil, cfg, "")
if err != nil {
b.Fatalf("NewUDPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
clientConn, err := net.DialUDP("udp", nil, proxy.Addr().(*net.UDPAddr))
if err != nil {
b.Fatalf("dial: %v", err)
}
defer clientConn.Close()
clientConn.SetDeadline(time.Now().Add(10 * time.Minute)) //nolint:errcheck
const msgSize = 1024 * (64 - 1)
payload := bytes.Repeat([]byte("U"), msgSize)
recv := make([]byte, 64*1024)
b.SetBytes(int64(msgSize))
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := clientConn.Write(payload); err != nil {
b.Fatalf("write: %v", err)
}
if _, err := clientConn.Read(recv); err != nil {
b.Fatalf("read: %v", err)
}
}
}
// BenchmarkUDPProxy_NewSession measures session-creation cost for the UDP
// proxy: each iteration uses a unique local port so that every packet triggers
// the slow-path decider call and upstream dial.
func BenchmarkUDPProxy_NewSession(b *testing.B) {
echoAddr, stopEcho := startUDPEchoServer(b)
defer stopEcho()
cfg := DefaultConfig()
cfg.ReadTimeout = 100 * time.Millisecond
proxy, err := NewUDPProxyWithConfig("127.0.0.1:0", "udp4", passThroughDecider(echoAddr), nil, cfg, "")
if err != nil {
b.Fatalf("NewUDPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
proxyUDPAddr := proxy.Addr().(*net.UDPAddr)
payload := []byte("ping")
recv := make([]byte, 64)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c, err := net.DialUDP("udp", nil, proxyUDPAddr)
if err != nil {
b.Fatalf("dial: %v", err)
}
c.SetDeadline(time.Now().Add(2 * time.Second)) //nolint:errcheck
if _, err := c.Write(payload); err != nil {
c.Close()
b.Fatalf("write: %v", err)
}
if _, err := c.Read(recv); err != nil {
c.Close()
b.Fatalf("read: %v", err)
}
c.Close()
}
}
// startTCPEchoServer and startUDPEchoServer are defined in proxy_test.go.
+38
View File
@@ -0,0 +1,38 @@
//go:build linux
package proxy
import (
"net"
"syscall"
)
// applyBindToDevice configures d to bind all outgoing connections to the named
// network interface via the SO_BINDTODEVICE socket option. The option is set
// in d.Control, which the net package invokes on the raw file descriptor
// immediately after socket creation and before connect(2), ensuring the kernel
// routes the connection through the specified device regardless of the routing
// table.
//
// If iface is empty, d is left unchanged and no binding is performed.
// d.Control is overwritten; any previously set hook is discarded.
func applyBindToDevice(d *net.Dialer, iface string) {
if iface == "" {
return
}
d.Control = func(network, address string, c syscall.RawConn) error {
var innerErr error
err := c.Control(func(fd uintptr) {
innerErr = syscall.SetsockoptString(
int(fd),
syscall.SOL_SOCKET,
syscall.SO_BINDTODEVICE,
iface,
)
})
if err != nil {
return err
}
return innerErr
}
}
+9
View File
@@ -0,0 +1,9 @@
//go:build !linux
package proxy
import "net"
// applyBindToDevice is a no-op on non-Linux platforms; SO_BINDTODEVICE is a
// Linux-specific socket option and has no equivalent here.
func applyBindToDevice(_ *net.Dialer, _ string) {}
+270
View File
@@ -0,0 +1,270 @@
package proxy
import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)
// ─── ConnContext ──────────────────────────────────────────────────────────────
// ConnContext holds all observable state for one proxy session.
//
// The counters are updated atomically and are safe for concurrent reads.
type ConnContext struct {
// id is a monotonically increasing session identifier (starts at 1).
id uint64
// peerAddr is the connecting client's address.
peerAddr net.Addr
// destIP is the upstream IP address chosen by DeciderFunc.
destIP net.IP
// destPort is the upstream port chosen by DeciderFunc.
destPort uint16
// createdAt is the wall-clock time the session was established.
createdAt time.Time
// lastSeen stores a UnixNano timestamp updated on every transferred packet/byte.
lastSeen atomic.Int64
// BytesIn counts bytes forwarded from upstream to the client.
BytesIn atomic.Uint64
// BytesOut counts bytes forwarded from the client to upstream.
BytesOut atomic.Uint64
// PacketsIn counts UDP datagrams forwarded from upstream to the client.
PacketsIn atomic.Uint64
// PacketsOut counts UDP datagrams forwarded from the client to upstream.
PacketsOut atomic.Uint64
// extraInfo is an optional user-defined object returned by DeciderFunc.
// It is set once at session creation and never modified.
extraInfo any
// cancel closes the session's context.
cancel func()
}
// newConnContext allocates a ConnContext and initialises lastSeen to now.
// destIP is normalised to 16-byte form (IPv4-in-IPv6) for consistent indexing.
func newConnContext(id uint64, peer net.Addr, destIP net.IP, destPort uint16, cancel func(), extraInfo any) *ConnContext {
now := time.Now()
c := &ConnContext{
id: id,
peerAddr: peer,
destIP: destIP.To16(),
destPort: destPort,
createdAt: now,
cancel: cancel,
extraInfo: extraInfo,
}
c.lastSeen.Store(now.UnixNano())
return c
}
// LastSeen returns the time of the most recently observed packet or byte.
func (c *ConnContext) LastSeen() time.Time {
return time.Unix(0, c.lastSeen.Load())
}
// touch updates lastSeen to the current time.
func (c *ConnContext) touch() {
c.lastSeen.Store(time.Now().UnixNano())
}
// Close cancels the session. Safe to call multiple times.
func (c *ConnContext) Close() {
if c.cancel != nil {
c.cancel()
}
}
// ID returns the session's monotonically increasing identifier (starts at 1).
func (c *ConnContext) ID() uint64 { return c.id }
// PeerAddr returns the connecting client's address.
func (c *ConnContext) PeerAddr() net.Addr { return c.peerAddr }
// DestIP returns the upstream IP address chosen by DeciderFunc.
func (c *ConnContext) DestIP() net.IP { return c.destIP }
// DestPort returns the upstream port chosen by DeciderFunc.
func (c *ConnContext) DestPort() uint16 { return c.destPort }
// CreatedAt returns the wall-clock time the session was established.
func (c *ConnContext) CreatedAt() time.Time { return c.createdAt }
// ExtraInfo returns the optional user-defined object returned by DeciderFunc.
// It is set once at session creation and never modified.
func (c *ConnContext) ExtraInfo() any { return c.extraInfo }
// ─── Metrics ──────────────────────────────────────────────────────────────────
// Metrics is a snapshot of session cache statistics.
type Metrics struct {
ActiveSessions uint64
TotalCreated uint64
TotalClosed uint64
}
func (m Metrics) String() string {
return fmt.Sprintf("active=%d created=%d closed=%d",
m.ActiveSessions, m.TotalCreated, m.TotalClosed)
}
// ─── Session cache ────────────────────────────────────────────────────────────
// destKey is the secondary-index key used to look up sessions by upstream
// destination. Using a fixed-size struct as a map key avoids string allocation
// and gives O(1) hashing.
type destKey struct {
ip [16]byte // IPv4-in-IPv6 form (To16)
port uint16
}
// makeDestKey builds a destKey from a pre-parsed IP and port.
// Returns (key, false) if ip is nil.
func makeDestKey(ip net.IP, port uint16) (destKey, bool) {
ip16 := ip.To16()
if ip16 == nil {
return destKey{}, false
}
var k destKey
copy(k.ip[:], ip16)
k.port = port
return k, true
}
// sessionCache is a concurrent-safe registry of live ConnContexts together
// with aggregate lifetime metrics.
type sessionCache struct {
mu sync.RWMutex
entries map[uint64]*ConnContext
// byDest is a secondary index: destKey → set of ConnContexts.
// It allows FindProxiedEgressConnection to skip iterating all entries.
byDest map[destKey]map[uint64]*ConnContext
totalCreated atomic.Uint64
totalClosed atomic.Uint64
}
func newSessionCache() *sessionCache {
return &sessionCache{
entries: make(map[uint64]*ConnContext, 64),
byDest: make(map[destKey]map[uint64]*ConnContext),
}
}
// add registers a new session.
func (c *sessionCache) add(ctx *ConnContext) {
k, hasKey := makeDestKey(ctx.destIP, ctx.destPort)
c.mu.Lock()
c.entries[ctx.id] = ctx
if hasKey {
inner := c.byDest[k]
if inner == nil {
inner = make(map[uint64]*ConnContext, 1)
c.byDest[k] = inner
}
inner[ctx.id] = ctx
}
c.mu.Unlock()
c.totalCreated.Add(1)
}
// remove unregisters a session. It is idempotent.
func (c *sessionCache) remove(ctx *ConnContext) {
c.mu.Lock()
if _, ok := c.entries[ctx.id]; ok {
delete(c.entries, ctx.id)
if k, hasKey := makeDestKey(ctx.destIP, ctx.destPort); hasKey {
inner := c.byDest[k]
delete(inner, ctx.id)
if len(inner) == 0 {
delete(c.byDest, k)
}
}
c.totalClosed.Add(1)
}
c.mu.Unlock()
}
// findByDest returns all active sessions whose upstream destination matches
// destIP and destPort. Returns nil if no matching session exists.
func (c *sessionCache) findByDest(destIP net.IP, destPort uint16) []*ConnContext {
ip16 := destIP.To16()
if ip16 == nil {
return nil
}
var k destKey
copy(k.ip[:], ip16)
k.port = destPort
c.mu.RLock()
inner := c.byDest[k]
if len(inner) == 0 {
c.mu.RUnlock()
return nil
}
result := make([]*ConnContext, 0, len(inner))
for _, ctx := range inner {
result = append(result, ctx)
}
c.mu.RUnlock()
return result
}
// hasByDest checks if there is an active session whose upstream destination
// matches destIP and destPort. Returns false if no matching session exists.
func (c *sessionCache) hasByDest(destIP net.IP, destPort uint16) bool {
ip16 := destIP.To16()
if ip16 == nil {
return false
}
var k destKey
copy(k.ip[:], ip16)
k.port = destPort
c.mu.RLock()
has := len(c.byDest[k]) > 0
c.mu.RUnlock()
return has
}
// get retrieves a session by ID.
func (c *sessionCache) get(id uint64) (*ConnContext, bool) {
c.mu.RLock()
ctx, ok := c.entries[id]
c.mu.RUnlock()
return ctx, ok
}
// len returns the current number of active sessions.
func (c *sessionCache) len() int {
c.mu.RLock()
n := len(c.entries)
c.mu.RUnlock()
return n
}
// metrics returns a consistent metrics snapshot.
func (c *sessionCache) metrics() Metrics {
c.mu.RLock()
active := uint64(len(c.entries))
c.mu.RUnlock()
return Metrics{
ActiveSessions: active,
TotalCreated: c.totalCreated.Load(),
TotalClosed: c.totalClosed.Load(),
}
}
// ─── Shared helpers ───────────────────────────────────────────────────────────
// idCounter is a global monotonic session ID source.
var idCounter atomic.Uint64
// nextID returns the next unique session ID (1-based).
func nextID() uint64 {
return idCounter.Add(1)
}
+120
View File
@@ -0,0 +1,120 @@
// Package proxy provides minimal, Layer-4 TCP and UDP proxies
// with injected routing decisions (DeciderFunc), structured logging, session
// tracking, and graceful shutdown.
package proxy
import (
"net"
"time"
)
// ─── Public API types ────────────────────────────────────────────────────────
// LocalBinding carries the local-side binding parameters for an outbound proxy
// connection. Both fields are optional and may be set independently.
type LocalBinding struct {
// IP is the local source address to bind the outgoing socket to.
// If nil, the OS selects an appropriate source address.
IP net.IP
// Interface is the name of the network interface (e.g. "eth0") to bind
// the outgoing socket to via SO_BINDTODEVICE (Linux only).
// An empty string disables interface-level binding.
Interface string
}
// DeciderFunc is called once per new session to determine the upstream
// destination and optional local binding parameters for the outgoing socket.
//
// local is the proxy's listen address; peer is the connecting client's address.
//
// It returns:
// - remoteIP: required upstream IP address.
// - remotePort: required upstream port.
// - binding: optional local binding; nil lets the OS choose freely.
// Set binding.IP to pin the source address, binding.Interface to restrict
// the socket to a specific network device (Linux only).
// - extraInfo: optional caller-defined value attached to the session's ConnContext.
// - err: non-nil rejects the session without dialling upstream.
type DeciderFunc func(local net.Addr, peer net.Addr) (remoteIP net.IP, remotePort uint16, binding *LocalBinding, extraInfo any, err error)
// Logger is the minimal structured logging interface expected by the proxies.
// Pass nil to disable all logging.
type Logger interface {
Debug(msg string, args ...any)
Info(msg string, args ...any)
Warn(msg string, args ...any)
Error(msg string, args ...any)
}
// noopLogger silently discards every log message.
type noopLogger struct{}
func (noopLogger) Debug(_ string, _ ...any) {}
func (noopLogger) Info(_ string, _ ...any) {}
func (noopLogger) Warn(_ string, _ ...any) {}
func (noopLogger) Error(_ string, _ ...any) {}
// resolveLogger returns l unchanged if non-nil, otherwise a noopLogger.
func resolveLogger(l Logger) Logger {
if l == nil {
return noopLogger{}
}
return l
}
func resolveLogPrefix(prefix string) string {
if prefix == "" {
return ""
}
return prefix + ": "
}
// ─── Configuration ────────────────────────────────────────────────────────────
// Config holds tunable parameters shared by both proxy types.
type Config struct {
// MaxSessions is the maximum number of concurrent sessions (0 = unlimited).
MaxSessions int
// ReadTimeout closes a session after this duration with no bytes received
// from src. The deadline is rolled forward on every successful read, so
// only truly silent sessions are evicted.
// Constructors default to 5 min for both TCP and UDP.
ReadTimeout time.Duration
// WriteTimeout is the maximum time allowed for a single write to complete
// before the session is torn down. It guards against a stalled destination
// holding a goroutine open indefinitely.
// Constructors default to 30s for TCP and UDP.
WriteTimeout time.Duration
// BufferSize is the size of copy buffers used by TCP pipes (bytes).
// Not used by UDP (UDP always uses 64 KiB buffers to handle max-sized datagrams).
// Each TCP session uses two buffers for bidirectional copying.
// Defaults to 32 KiB when <= 0.
BufferSize int
// DialTimeout is the maximum time the TCP proxy waits when dialling the
// upstream destination for a new session. The dial is also cancelled
// immediately whenever Shutdown is called, regardless of this value.
// Defaults to 10 s when <= 0.
DialTimeout time.Duration
}
const DEFAULT_DIAL_TIMEOUT = 10 * time.Second
const DEFAULT_BUFFER_SIZE = 32 * 1024
const DEFAULT_MAX_SESSIONS = 2048
const DEFAULT_READ_TIMEOUT = 5 * time.Minute
const DEFAULT_WRITE_TIMEOUT = 30 * time.Second
// DefaultConfig returns a sensible default Config.
func DefaultConfig() Config {
return Config{
MaxSessions: DEFAULT_MAX_SESSIONS,
BufferSize: DEFAULT_BUFFER_SIZE,
DialTimeout: DEFAULT_DIAL_TIMEOUT,
ReadTimeout: DEFAULT_READ_TIMEOUT,
WriteTimeout: DEFAULT_WRITE_TIMEOUT,
}
}
+502
View File
@@ -0,0 +1,502 @@
package proxy
import (
"bytes"
"context"
"fmt"
"io"
"net"
"sync/atomic"
"testing"
"time"
)
// ─── helpers ──────────────────────────────────────────────────────────────────
// passThroughDecider always routes to dest.
func passThroughDecider(dest string) DeciderFunc {
addr, _ := net.ResolveTCPAddr("tcp", dest)
return func(_, _ net.Addr) (net.IP, uint16, *LocalBinding, any, error) {
if addr == nil {
return nil, 0, nil, nil, fmt.Errorf("invalid dest %q", dest)
}
return addr.IP, uint16(addr.Port), nil, nil, nil
}
}
// refuseDecider always rejects sessions.
func refuseDecider(_ net.Addr, _ net.Addr) (net.IP, uint16, *LocalBinding, any, error) {
return nil, 0, nil, nil, fmt.Errorf("rejected")
}
// startTCPEchoServer starts a TCP echo server on a random port.
// It returns the address and a stop function. Accepts testing.TB so it works
// in both tests and benchmarks.
func startTCPEchoServer(tb testing.TB) (addr string, stop func()) {
tb.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
tb.Fatalf("echo server listen: %v", err)
}
done := make(chan struct{})
go func() {
defer close(done)
for {
conn, err := ln.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
io.Copy(c, c) //nolint:errcheck
}(conn)
}
}()
return ln.Addr().String(), func() {
ln.Close()
<-done
}
}
// startUDPEchoServer starts a UDP echo server on a random port.
func startUDPEchoServer(tb testing.TB) (addr string, stop func()) {
tb.Helper()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
if err != nil {
tb.Fatalf("udp echo server listen: %v", err)
}
done := make(chan struct{})
go func() {
defer close(done)
buf := make([]byte, 64*1024)
for {
n, peer, err := conn.ReadFromUDP(buf)
if err != nil {
return
}
conn.WriteToUDP(buf[:n], peer) //nolint:errcheck
}
}()
return conn.LocalAddr().String(), func() {
conn.Close()
<-done
}
}
// ─── TCP tests ────────────────────────────────────────────────────────────────
func TestTCPProxy_ConnectAndForward(t *testing.T) {
echoAddr, stopEcho := startTCPEchoServer(t)
defer stopEcho()
proxy, err := NewTCPProxy("127.0.0.1:0", "tcp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
t.Fatalf("NewTCPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
conn, err := net.DialTimeout("tcp", proxy.Addr().String(), 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
payload := []byte("hello proxy")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write: %v", err)
}
buf := make([]byte, len(payload))
conn.SetDeadline(time.Now().Add(2 * time.Second)) //nolint:errcheck
if _, err := io.ReadFull(conn, buf); err != nil {
t.Fatalf("read: %v", err)
}
if !bytes.Equal(buf, payload) {
t.Fatalf("echo mismatch: got %q want %q", buf, payload)
}
}
func TestTCPProxy_BidirectionalBytes(t *testing.T) {
echoAddr, stopEcho := startTCPEchoServer(t)
defer stopEcho()
proxy, err := NewTCPProxy("127.0.0.1:0", "tcp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
t.Fatalf("NewTCPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
const msgSize = 128 * 1024
payload := bytes.Repeat([]byte("X"), msgSize)
conn, err := net.DialTimeout("tcp", proxy.Addr().String(), 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(5 * time.Second)) //nolint:errcheck
errc := make(chan error, 1)
recvd := make([]byte, msgSize)
go func() {
_, err := io.ReadFull(conn, recvd)
errc <- err
}()
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write: %v", err)
}
if err := <-errc; err != nil {
t.Fatalf("read: %v", err)
}
if !bytes.Equal(recvd, payload) {
t.Fatal("bidirectional echo mismatch")
}
}
func TestTCPProxy_SessionCleanupOnClose(t *testing.T) {
echoAddr, stopEcho := startTCPEchoServer(t)
defer stopEcho()
proxy, err := NewTCPProxy("127.0.0.1:0", "tcp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
t.Fatalf("NewTCPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
conn, err := net.DialTimeout("tcp", proxy.Addr().String(), 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
// Wait for the session to register.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if proxy.Metrics().ActiveSessions == 1 {
break
}
time.Sleep(5 * time.Millisecond)
}
if proxy.Metrics().ActiveSessions != 1 {
t.Fatalf("expected 1 active session, got %d", proxy.Metrics().ActiveSessions)
}
conn.Close()
// Wait for cleanup.
deadline = time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if proxy.Metrics().ActiveSessions == 0 {
break
}
time.Sleep(5 * time.Millisecond)
}
if proxy.Metrics().ActiveSessions != 0 {
t.Fatalf("session not cleaned up: active=%d", proxy.Metrics().ActiveSessions)
}
if proxy.Metrics().TotalClosed != 1 {
t.Fatalf("expected TotalClosed=1, got %d", proxy.Metrics().TotalClosed)
}
}
func TestTCPProxy_DeciderRejectsSession(t *testing.T) {
proxy, err := NewTCPProxy("127.0.0.1:0", "tcp4", refuseDecider, nil, "")
if err != nil {
t.Fatalf("NewTCPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
conn, err := net.DialTimeout("tcp", proxy.Addr().String(), 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(time.Second)) //nolint:errcheck
buf := make([]byte, 4)
_, err = conn.Read(buf)
if err == nil {
t.Fatal("expected connection to be closed by proxy")
}
}
func TestTCPProxy_MaxSessions(t *testing.T) {
echoAddr, stopEcho := startTCPEchoServer(t)
defer stopEcho()
cfg := DefaultConfig()
cfg.MaxSessions = 1
proxy, err := NewTCPProxyWithConfig("127.0.0.1:0", "tcp4", passThroughDecider(echoAddr), nil, cfg, "")
if err != nil {
t.Fatalf("NewTCPProxyWithConfig: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
// First connection should succeed and stay open.
c1, err := net.DialTimeout("tcp", proxy.Addr().String(), 2*time.Second)
if err != nil {
t.Fatalf("dial c1: %v", err)
}
defer c1.Close()
// Give the proxy time to accept and register c1.
time.Sleep(50 * time.Millisecond)
// Second connection: proxy should accept TCP but immediately close it.
c2, err := net.DialTimeout("tcp", proxy.Addr().String(), 2*time.Second)
if err != nil {
t.Fatalf("dial c2: %v", err)
}
defer c2.Close()
c2.SetDeadline(time.Now().Add(time.Second)) //nolint:errcheck
buf := make([]byte, 4)
_, err = c2.Read(buf)
if err == nil {
t.Fatal("expected c2 to be rejected")
}
}
func TestTCPProxy_GracefulShutdown(t *testing.T) {
echoAddr, stopEcho := startTCPEchoServer(t)
defer stopEcho()
proxy, err := NewTCPProxy("127.0.0.1:0", "tcp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
t.Fatalf("NewTCPProxy: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := proxy.Shutdown(ctx); err != nil {
t.Fatalf("Shutdown: %v", err)
}
}
// ─── UDP tests ────────────────────────────────────────────────────────────────
func TestUDPProxy_SessionCreation(t *testing.T) {
echoAddr, stopEcho := startUDPEchoServer(t)
defer stopEcho()
proxy, err := NewUDPProxy("127.0.0.1:0", "udp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
t.Fatalf("NewUDPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
clientConn, err := net.DialUDP("udp", nil,
proxy.Addr().(*net.UDPAddr))
if err != nil {
t.Fatalf("dial: %v", err)
}
defer clientConn.Close()
payload := []byte("hello udp")
clientConn.SetDeadline(time.Now().Add(2 * time.Second)) //nolint:errcheck
if _, err := clientConn.Write(payload); err != nil {
t.Fatalf("write: %v", err)
}
buf := make([]byte, 256)
n, err := clientConn.Read(buf)
if err != nil {
t.Fatalf("read: %v", err)
}
if !bytes.Equal(buf[:n], payload) {
t.Fatalf("echo mismatch: got %q want %q", buf[:n], payload)
}
// Session should be registered.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if proxy.Metrics().ActiveSessions == 1 {
break
}
time.Sleep(5 * time.Millisecond)
}
if proxy.Metrics().ActiveSessions != 1 {
t.Fatalf("expected 1 session, got %d", proxy.Metrics().ActiveSessions)
}
}
func TestUDPProxy_ReplyRouting(t *testing.T) {
echoAddr, stopEcho := startUDPEchoServer(t)
defer stopEcho()
proxy, err := NewUDPProxy("127.0.0.1:0", "udp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
t.Fatalf("NewUDPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
proxyUDPAddr := proxy.Addr().(*net.UDPAddr)
const numClients = 3
const numMessages = 5
errc := make(chan error, numClients)
for i := 0; i < numClients; i++ {
tag := fmt.Sprintf("client%d", i)
go func(tag string) {
c, err := net.DialUDP("udp", nil, proxyUDPAddr)
if err != nil {
errc <- fmt.Errorf("%s dial: %w", tag, err)
return
}
defer c.Close()
c.SetDeadline(time.Now().Add(3 * time.Second)) //nolint:errcheck
for j := 0; j < numMessages; j++ {
msg := fmt.Sprintf("%s-msg%d", tag, j)
if _, err := c.Write([]byte(msg)); err != nil {
errc <- fmt.Errorf("%s write: %w", tag, err)
return
}
buf := make([]byte, 256)
n, err := c.Read(buf)
if err != nil {
errc <- fmt.Errorf("%s read: %w", tag, err)
return
}
if string(buf[:n]) != msg {
errc <- fmt.Errorf("%s: got %q want %q", tag, buf[:n], msg)
return
}
}
errc <- nil
}(tag)
}
for i := 0; i < numClients; i++ {
if err := <-errc; err != nil {
t.Error(err)
}
}
}
func TestUDPProxy_IdleTimeoutCleanup(t *testing.T) {
echoAddr, stopEcho := startUDPEchoServer(t)
defer stopEcho()
cfg := DefaultConfig()
cfg.ReadTimeout = 200 * time.Millisecond
proxy, err := NewUDPProxyWithConfig("127.0.0.1:0", "udp4", passThroughDecider(echoAddr), nil, cfg, "")
if err != nil {
t.Fatalf("NewUDPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
clientConn, err := net.DialUDP("udp", nil, proxy.Addr().(*net.UDPAddr))
if err != nil {
t.Fatalf("dial: %v", err)
}
defer clientConn.Close()
clientConn.SetDeadline(time.Now().Add(time.Second)) //nolint:errcheck
payload := []byte("trigger session creation")
if _, err := clientConn.Write(payload); err != nil {
t.Fatalf("write: %v", err)
}
buf := make([]byte, 256)
if _, err := clientConn.Read(buf); err != nil {
t.Fatalf("initial read: %v", err)
}
// Confirm session is alive.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if proxy.Metrics().ActiveSessions == 1 {
break
}
time.Sleep(5 * time.Millisecond)
}
if proxy.Metrics().ActiveSessions != 1 {
t.Fatal("session did not register")
}
// Let it idle out.
time.Sleep(600 * time.Millisecond)
deadline = time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if proxy.Metrics().ActiveSessions == 0 {
break
}
time.Sleep(10 * time.Millisecond)
}
if proxy.Metrics().ActiveSessions != 0 {
t.Fatalf("idle session not cleaned up: active=%d", proxy.Metrics().ActiveSessions)
}
}
func TestUDPProxy_MaxSessions(t *testing.T) {
// Count how many sessions the decider accepts; reject beyond limit.
var accepted atomic.Int32
const limit = 2
decider := func(local, peer net.Addr) (net.IP, uint16, *LocalBinding, any, error) {
if accepted.Load() >= limit {
return nil, 0, nil, nil, fmt.Errorf("max sessions")
}
accepted.Add(1)
return nil, 0, nil, nil, fmt.Errorf("no upstream needed for this test")
}
cfg := DefaultConfig()
cfg.MaxSessions = limit
proxy, err := NewUDPProxyWithConfig("127.0.0.1:0", "udp4", decider, nil, cfg, "")
if err != nil {
t.Fatalf("NewUDPProxy: %v", err)
}
defer proxy.Shutdown(context.Background()) //nolint:errcheck
// The proxy itself enforces MaxSessions, so the decider may or may not
// be called for the first 'limit' packets. Just verify the proxy starts.
if proxy.Addr() == nil {
t.Fatal("proxy has no address")
}
}
func TestUDPProxy_GracefulShutdown(t *testing.T) {
echoAddr, stopEcho := startUDPEchoServer(t)
defer stopEcho()
proxy, err := NewUDPProxy("127.0.0.1:0", "udp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
t.Fatalf("NewUDPProxy: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := proxy.Shutdown(ctx); err != nil {
t.Fatalf("Shutdown: %v", err)
}
}
// ─── Misc ─────────────────────────────────────────────────────────────────────
func TestNilLogger(t *testing.T) {
echoAddr, stopEcho := startTCPEchoServer(t)
defer stopEcho()
_, err := NewTCPProxy("127.0.0.1:0", "tcp4", passThroughDecider(echoAddr), nil, "")
if err != nil {
t.Fatalf("nil Logger should be accepted but got: %v", err)
}
}
func TestNilDeciderRejected(t *testing.T) {
if _, err := NewTCPProxy("127.0.0.1:0", "tcp4", nil, nil, ""); err == nil {
t.Fatal("nil DeciderFunc should be rejected")
}
if _, err := NewUDPProxy("127.0.0.1:0", "udp4", nil, nil, ""); err == nil {
t.Fatal("nil DeciderFunc should be rejected")
}
}
func TestMetricsString(t *testing.T) {
m := Metrics{ActiveSessions: 3, TotalCreated: 10, TotalClosed: 7}
s := m.String()
if s == "" {
t.Fatal("Metrics.String() returned empty string")
}
}
+328
View File
@@ -0,0 +1,328 @@
package proxy
import (
"context"
"errors"
"io"
"net"
"os"
"sync"
"time"
)
// TCPProxy is a Layer-4 TCP proxy that routes each accepted connection through
// a DeciderFunc before dialling the upstream destination.
//
// It is safe to call Shutdown from any goroutine and from multiple goroutines
// simultaneously (only the first call has effect).
type TCPProxy struct {
decider DeciderFunc
log Logger
cfg Config
network string
logPrefix string
listener net.Listener
bufPool sync.Pool
cache *sessionCache
shutdownCtx context.Context
shutdown context.CancelFunc
once sync.Once
wg sync.WaitGroup
}
// NewTCPProxy creates and starts a TCP proxy that listens on listenAddr.
// It uses DefaultConfig for tuning parameters.
//
// The proxy begins accepting connections immediately; call Shutdown to stop it.
//
// Parameters:
// - listenAddr: the local address to listen on (e.g. "0.0.0.0:719")
// - network: the network type to listen on (e.g. "tcp4", "tcp6", "udp4", "udp6")
// - decider: a function that determines the upstream destination for each
// accepted connection. See DeciderFunc for details.
// - logger: an optional Logger for debug/info/warn messages. If nil, a
// default logger is used.
// - logPrefix: a string prepended to every log message (e.g. "tcp proxy IPv4").
// Pass an empty string to log messages without a prefix.
func NewTCPProxy(listenAddr string, network string, decider DeciderFunc, logger Logger, logPrefix string) (*TCPProxy, error) {
return NewTCPProxyWithConfig(listenAddr, network, decider, logger, DefaultConfig(), logPrefix)
}
// NewTCPProxyWithConfig is like NewTCPProxy but accepts a custom Config.
func NewTCPProxyWithConfig(listenAddr string, network string, decider DeciderFunc, logger Logger, cfg Config, logPrefix string) (*TCPProxy, error) {
if decider == nil {
return nil, errors.New("proxy: decider must not be nil")
}
if cfg.BufferSize <= 0 {
cfg.BufferSize = DEFAULT_BUFFER_SIZE
}
if cfg.DialTimeout <= 0 {
cfg.DialTimeout = DEFAULT_DIAL_TIMEOUT
}
if cfg.ReadTimeout <= 0 {
cfg.ReadTimeout = DEFAULT_READ_TIMEOUT
}
if cfg.WriteTimeout <= 0 {
cfg.WriteTimeout = DEFAULT_WRITE_TIMEOUT
}
ln, err := net.Listen(network, listenAddr)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
p := &TCPProxy{
decider: decider,
log: resolveLogger(logger),
cfg: cfg,
network: network,
logPrefix: resolveLogPrefix(logPrefix),
listener: ln,
cache: newSessionCache(),
shutdownCtx: ctx,
shutdown: cancel,
}
bufSize := cfg.BufferSize
p.bufPool.New = func() interface{} {
b := make([]byte, bufSize)
return &b
}
p.wg.Add(1)
go p.acceptLoop()
p.log.Debug(p.logPrefix+"listening", "addr", ln.Addr())
return p, nil
}
// Addr returns the address the proxy is listening on.
func (p *TCPProxy) Addr() net.Addr {
return p.listener.Addr()
}
// Metrics returns a snapshot of the session cache statistics.
func (p *TCPProxy) Metrics() Metrics {
return p.cache.metrics()
}
// FindProxiedEgressConnection returns all active (including establishing)
// sessions whose upstream destination matches destIP and destPort.
// Returns nil if no matching session exists.
func (p *TCPProxy) FindProxiedEgressConnection(destIP net.IP, destPort uint16) []*ConnContext {
return p.cache.findByDest(destIP, destPort)
}
// HasProxiedEgressConnection checks if there is an active session whose
// upstream destination matches destIP and destPort.
func (p *TCPProxy) HasProxiedEgressConnection(destIP net.IP, destPort uint16) bool {
return p.cache.hasByDest(destIP, destPort)
}
// Shutdown stops accepting new connections, signals all active sessions to
// close, and waits for them to finish. If ctx expires before all sessions
// drain, it returns ctx.Err() but does not leak goroutines — connections are
// already being forcibly closed by the context cancellation.
func (p *TCPProxy) Shutdown(ctx context.Context) error {
var retErr error
p.once.Do(func() {
p.shutdown()
p.listener.Close()
done := make(chan struct{})
go func() {
p.wg.Wait()
close(done)
}()
select {
case <-done:
p.log.Debug(p.logPrefix+"shutdown complete", "metrics", p.cache.metrics())
case <-ctx.Done():
retErr = ctx.Err()
p.log.Warn(p.logPrefix+"forced shutdown", "err", retErr)
}
})
return retErr
}
// ─── accept loop ──────────────────────────────────────────────────────────────
func (p *TCPProxy) acceptLoop() {
defer p.wg.Done()
var backoff time.Duration
for {
conn, err := p.listener.Accept()
if err != nil {
select {
case <-p.shutdownCtx.Done():
return
default:
// Transient OS error (e.g. EMFILE). Back off exponentially so
// a sustained error produces at most ~1 log line per second.
if backoff == 0 {
backoff = 5 * time.Millisecond
} else {
backoff = min(backoff*2, time.Second)
}
p.log.Error(p.logPrefix+"accept error", "err", err)
time.Sleep(backoff)
continue
}
}
backoff = 0 // reset on success
if p.cfg.MaxSessions > 0 && p.cache.len() >= p.cfg.MaxSessions {
p.log.Warn(p.logPrefix+"max sessions reached, rejecting connection", "max", p.cfg.MaxSessions, "addr", conn.RemoteAddr())
conn.Close()
continue
}
p.wg.Add(1)
go p.handleConn(conn)
}
}
// ─── per-connection handler ───────────────────────────────────────────────────
func (p *TCPProxy) handleConn(clientConn net.Conn) {
defer p.wg.Done()
defer clientConn.Close()
// Determine upstream destination.
destIP, destPort, binding, extraInfo, err := p.decider(p.listener.Addr(), clientConn.RemoteAddr())
if err != nil {
p.log.Warn(p.logPrefix+"decider rejected connection", "addr", clientConn.RemoteAddr(), "err", err)
return
}
destAddr := (&net.TCPAddr{IP: destIP, Port: int(destPort)}).String()
// Register the session immediately so FindProxiedEgressConnection can
// locate it before the upstream dial completes.
sessCtx, cancel := context.WithCancel(p.shutdownCtx)
connCtx := newConnContext(
nextID(),
clientConn.RemoteAddr(),
destIP,
destPort,
cancel,
extraInfo,
)
p.cache.add(connCtx)
defer func() {
cancel()
p.cache.remove(connCtx)
p.log.Debug(p.logPrefix+"session closed", "session", connCtx.id, "dest_ip", connCtx.destIP, "dest_port", connCtx.destPort, "bytes_in", connCtx.BytesIn.Load(), "bytes_out", connCtx.BytesOut.Load())
}()
// DialContext is cancelled immediately if the proxy is shut down.
dialer := net.Dialer{Timeout: p.cfg.DialTimeout}
if binding != nil && binding.IP != nil {
dialer.LocalAddr = &net.TCPAddr{IP: binding.IP}
}
if binding != nil {
applyBindToDevice(&dialer, binding.Interface)
}
upstreamConn, err := dialer.DialContext(p.shutdownCtx, p.network, destAddr)
if err != nil {
if p.shutdownCtx.Err() != nil {
// Proxy is shutting down; this is expected, not an error.
return
}
p.log.Error(p.logPrefix+"dial failed", "addr", destAddr, "err", err)
return
}
defer upstreamConn.Close()
p.log.Debug(p.logPrefix+"session started", "session", connCtx.id, "from", clientConn.RemoteAddr(), "to", destAddr)
// Watchdog: when the proxy shuts down (or the caller cancels the session),
// force-close both ends so the copy goroutines unblock immediately.
go func() {
<-sessCtx.Done()
clientConn.Close()
upstreamConn.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
// client → upstream
go func() {
defer wg.Done()
n := p.pipe(upstreamConn, clientConn, connCtx)
connCtx.BytesOut.Add(uint64(n))
// Propagate clean EOF downstream.
halfClose(upstreamConn)
}()
// upstream → client
go func() {
defer wg.Done()
n := p.pipe(clientConn, upstreamConn, connCtx)
connCtx.BytesIn.Add(uint64(n))
halfClose(clientConn)
}()
wg.Wait()
}
// pipe copies from src to dst using a manual read/write loop with a pooled
// buffer and returns the total bytes transferred.
//
// io.CopyBuffer is not used because it provides no opportunity
// to call SetReadDeadline/SetWriteDeadline between individual I/O operations.
func (p *TCPProxy) pipe(dst, src net.Conn, connCtx *ConnContext) int64 {
bp := p.bufPool.Get().(*[]byte)
defer p.bufPool.Put(bp)
buf := *bp
var total int64
for {
_ = src.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout))
nr, readErr := src.Read(buf)
if nr > 0 {
connCtx.touch() // session is active; reset idle tracking
_ = dst.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
nw, writeErr := dst.Write(buf[:nr])
total += int64(nw)
if writeErr != nil {
if errors.Is(writeErr, os.ErrDeadlineExceeded) {
p.log.Debug(p.logPrefix+"session write timeout", "session", connCtx.id, "dest_ip", connCtx.destIP, "dest_port", connCtx.destPort, "timeout", p.cfg.WriteTimeout)
} else if !isClosedConnErr(writeErr) {
p.log.Debug(p.logPrefix+"session write error", "session", connCtx.id, "dest_ip", connCtx.destIP, "dest_port", connCtx.destPort, "err", writeErr)
}
break
}
}
if readErr != nil {
if errors.Is(readErr, os.ErrDeadlineExceeded) {
p.log.Debug(p.logPrefix+"session read timeout", "session", connCtx.id, "dest_ip", connCtx.destIP, "dest_port", connCtx.destPort, "timeout", p.cfg.ReadTimeout)
} else if !isClosedConnErr(readErr) {
p.log.Debug(p.logPrefix+"session read error", "session", connCtx.id, "dest_ip", connCtx.destIP, "dest_port", connCtx.destPort, "err", readErr)
}
break
}
}
return total
}
// halfClose attempts a TCP write-shutdown so the peer receives EOF on its
// read side while the connection stays open for the other direction.
func halfClose(conn net.Conn) {
type canCloseWrite interface{ CloseWrite() error }
if c, ok := conn.(canCloseWrite); ok {
_ = c.CloseWrite()
}
}
// isClosedConnErr reports whether err is a clean EOF or a closed-socket error
// that is expected during normal session teardown or proxy shutdown.
func isClosedConnErr(err error) bool {
return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)
}
+439
View File
@@ -0,0 +1,439 @@
package proxy
import (
"context"
"errors"
"net"
"sync"
"time"
)
// udpSession is the NAT entry for a single client endpoint.
type udpSession struct {
connCtx *ConnContext
// remote is the per-session UDP socket dialled to the upstream.
// net.Conn is used so platform-specific dialers (e.g. SO_BINDTODEVICE on
// Linux) can return different concrete types without changing the callers.
remote net.Conn
}
// UDPProxy is a Layer-4 UDP proxy. It uses a single listening UDPConn and
// maintains a NAT-like table that maps each client address to a dedicated
// upstream socket. Sessions are evicted after Config.ReadTimeout of inactivity
// (default 5 min).
type UDPProxy struct {
decider DeciderFunc
log Logger
cfg Config
logPrefix string
conn *net.UDPConn // listener
cache *sessionCache
// sessions maps clientAddr.String() → *udpSession
mu sync.RWMutex
sessions map[string]*udpSession
shutdownCtx context.Context
shutdown context.CancelFunc
once sync.Once
wg sync.WaitGroup
}
// udpBufPool holds reusable 64 KiB byte slices for UDP datagram I/O.
// 64 KiB is the maximum size of a UDP payload, so this size avoids fragmentation
// for any datagram. The pool amortizes the cost of allocating these buffers,
// which are large enough to trigger GC pressure if allocated on every packet.
var udpBufPool = sync.Pool{
New: func() interface{} {
b := make([]byte, 64*1024)
return &b
},
}
// NewUDPProxy creates and starts a UDP proxy listening on listenAddr.
// It uses DefaultConfig for tuning parameters.
//
// Parameters:
// - listenAddr: the local address to listen on (e.g. "0.0.0.0:719")
// - network: the network type to listen on (e.g. "udp4", "udp6")
// - decider: a function that determines the upstream destination for each
// accepted connection. See DeciderFunc for details.
// - logger: an optional Logger for debug/info/warn messages. If nil, a
// default logger is used.
// - logPrefix: a string prepended to every log message (e.g. "udp proxy IPv4").
// Pass an empty string to log messages without a prefix.
func NewUDPProxy(listenAddr string, network string, decider DeciderFunc, logger Logger, logPrefix string) (*UDPProxy, error) {
return NewUDPProxyWithConfig(listenAddr, network, decider, logger, DefaultConfig(), logPrefix)
}
// NewUDPProxyWithConfig is like NewUDPProxy but accepts a custom Config.
func NewUDPProxyWithConfig(listenAddr string, network string, decider DeciderFunc, logger Logger, cfg Config, logPrefix string) (*UDPProxy, error) {
if decider == nil {
return nil, errors.New("proxy: decider must not be nil")
}
if cfg.ReadTimeout <= 0 {
cfg.ReadTimeout = DEFAULT_READ_TIMEOUT
}
if cfg.WriteTimeout <= 0 {
cfg.WriteTimeout = DEFAULT_WRITE_TIMEOUT
}
addr, err := net.ResolveUDPAddr(network, listenAddr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP(network, addr)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
p := &UDPProxy{
decider: decider,
log: resolveLogger(logger),
cfg: cfg,
logPrefix: resolveLogPrefix(logPrefix),
conn: conn,
cache: newSessionCache(),
sessions: make(map[string]*udpSession, 64),
shutdownCtx: ctx,
shutdown: cancel,
}
p.wg.Add(2)
go p.readLoop()
go p.cleanupLoop()
p.log.Debug(p.logPrefix+"listening", "addr", conn.LocalAddr())
return p, nil
}
// Addr returns the address the proxy is listening on.
func (p *UDPProxy) Addr() net.Addr {
return p.conn.LocalAddr()
}
// Metrics returns a snapshot of the session cache statistics.
func (p *UDPProxy) Metrics() Metrics {
return p.cache.metrics()
}
// FindProxiedEgressConnection returns all active (including establishing)
// sessions whose upstream destination matches destIP and destPort.
// Returns nil if no matching session exists.
func (p *UDPProxy) FindProxiedEgressConnection(destIP net.IP, destPort uint16) []*ConnContext {
return p.cache.findByDest(destIP, destPort)
}
// HasProxiedEgressConnection checks if there is an active session whose
// upstream destination matches destIP and destPort.
func (p *UDPProxy) HasProxiedEgressConnection(destIP net.IP, destPort uint16) bool {
return p.cache.hasByDest(destIP, destPort)
}
// Shutdown tears down the proxy. It closes the listen socket, cancels all
// sessions, and waits for goroutines to exit or until ctx expires.
func (p *UDPProxy) Shutdown(ctx context.Context) error {
var retErr error
p.once.Do(func() {
// Signal all goroutines and unblock ReadFromUDP.
p.shutdown()
p.conn.Close()
done := make(chan struct{})
go func() {
p.wg.Wait()
close(done)
}()
select {
case <-done:
p.log.Debug(p.logPrefix+"shutdown complete", "metrics", p.cache.metrics())
case <-ctx.Done():
retErr = ctx.Err()
p.log.Warn(p.logPrefix+"forced shutdown", "err", retErr)
}
})
return retErr
}
// ─── Inbound read loop ────────────────────────────────────────────────────────
func (p *UDPProxy) readLoop() {
defer p.wg.Done()
var backoff time.Duration
for {
bp := udpBufPool.Get().(*[]byte)
n, clientAddr, err := p.conn.ReadFromUDP(*bp)
if err != nil {
udpBufPool.Put(bp)
select {
case <-p.shutdownCtx.Done():
return
default:
if errors.Is(err, net.ErrClosed) {
if p.shutdownCtx.Err() == nil {
p.log.Error(p.logPrefix+"socket unexpectedly closed", "err", err)
}
return
}
// Transient (e.g. ENOBUFS, ICMP-delivered ECONNREFUSED).
// Back off exponentially so a sustained error produces at
// most ~1 log line per second.
if backoff == 0 {
backoff = 5 * time.Millisecond
} else {
backoff = min(backoff*2, time.Second)
}
p.log.Error(p.logPrefix+"read error", "err", err)
time.Sleep(backoff)
continue
}
}
backoff = 0 // reset on success
// Pass the slice directly to handlePacket — it uses the data
// synchronously (all Write calls complete before it returns), so
// we can return the buffer to the pool immediately after.
p.handlePacket(clientAddr, (*bp)[:n])
udpBufPool.Put(bp)
}
}
// handlePacket routes one inbound datagram to the correct upstream session.
func (p *UDPProxy) handlePacket(clientAddr *net.UDPAddr, data []byte) {
key := clientAddr.String()
// Fast path: session already exists.
p.mu.RLock()
sess, ok := p.sessions[key]
p.mu.RUnlock()
if ok {
sess.connCtx.touch()
_ = sess.remote.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
if _, err := sess.remote.Write(data); err != nil {
if !isClosedConnErr(err) {
p.log.Warn(p.logPrefix+"write to upstream failed", "client", key, "err", err)
}
return
}
sess.connCtx.PacketsOut.Add(1)
sess.connCtx.BytesOut.Add(uint64(len(data)))
return
}
// Slow path: new client — enforce session limit before allocating.
if p.cfg.MaxSessions > 0 && p.cache.len() >= p.cfg.MaxSessions {
p.log.Warn(p.logPrefix+"max sessions reached, dropping packet", "max", p.cfg.MaxSessions, "client", key)
return
}
destIP, destPort, binding, extraInfo, err := p.decider(p.conn.LocalAddr(), clientAddr)
if err != nil {
p.log.Warn(p.logPrefix+"decider rejected connection", "client", key, "err", err)
return
}
// Register the session immediately so FindProxiedEgressConnection can
// locate it before the upstream dial completes.
sessCtx, cancel := context.WithCancel(p.shutdownCtx)
connCtx := newConnContext(nextID(), clientAddr, destIP, destPort, cancel, extraInfo)
p.cache.add(connCtx)
remoteAddr := &net.UDPAddr{IP: destIP, Port: int(destPort)}
d := net.Dialer{}
if binding != nil && binding.IP != nil {
d.LocalAddr = &net.UDPAddr{IP: binding.IP}
}
if binding != nil {
applyBindToDevice(&d, binding.Interface)
}
remoteConn, err := d.DialContext(sessCtx, "udp", remoteAddr.String())
if err != nil {
p.cache.remove(connCtx)
cancel()
p.log.Error(p.logPrefix+"dial failed", "addr", remoteAddr, "err", err)
return
}
sess = &udpSession{connCtx: connCtx, remote: remoteConn}
// Write-lock: check again to prevent duplicate sessions under contention.
p.mu.Lock()
if existing, exists := p.sessions[key]; exists {
p.mu.Unlock()
cancel()
remoteConn.Close()
p.cache.remove(connCtx) // undo early registration; use the existing session
// Reuse the existing session for this datagram.
existing.connCtx.touch()
_ = existing.remote.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
if _, err := existing.remote.Write(data); err != nil {
p.log.Warn(p.logPrefix+"write to upstream failed", "client", key, "err", err)
return
}
existing.connCtx.PacketsOut.Add(1)
existing.connCtx.BytesOut.Add(uint64(len(data)))
return
}
p.sessions[key] = sess
p.mu.Unlock()
p.log.Debug(p.logPrefix+"session started", "session", connCtx.id, "from", key, "to", remoteAddr)
// Launch reverse-direction goroutine (upstream → client).
p.wg.Add(1)
go p.forwardFromRemote(sessCtx, sess, clientAddr)
// Forward the first datagram.
connCtx.touch()
_ = remoteConn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
if _, err := remoteConn.Write(data); err != nil {
p.log.Warn(p.logPrefix+"initial write to upstream failed", "err", err)
return
}
connCtx.PacketsOut.Add(1)
connCtx.BytesOut.Add(uint64(len(data)))
}
// ─── Upstream → client forwarder ─────────────────────────────────────────────
// forwardFromRemote reads replies from the upstream socket and writes them
// back to the originating client. It exits when the context is cancelled,
// an idle timeout fires, or an unrecoverable read/write error occurs.
func (p *UDPProxy) forwardFromRemote(ctx context.Context, sess *udpSession, clientAddr *net.UDPAddr) {
defer p.wg.Done()
defer p.removeSession(sess, clientAddr.String())
for {
// Check for cancellation before each read.
select {
case <-ctx.Done():
return
default:
}
// Roll the read deadline before every read so a truly silent upstream
// is detected within ReadTimeout.
_ = sess.remote.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout))
bp := udpBufPool.Get().(*[]byte)
n, err := sess.remote.Read(*bp)
if err != nil {
udpBufPool.Put(bp)
select {
case <-ctx.Done():
return
default:
if isTimeoutErr(err) {
p.log.Debug(p.logPrefix+"session idle timeout", "session", sess.connCtx.id, "dest_ip", sess.connCtx.destIP, "dest_port", sess.connCtx.destPort, "client", clientAddr)
return
}
if !isClosedConnErr(err) {
p.log.Debug(p.logPrefix+"session read error", "session", sess.connCtx.id, "dest_ip", sess.connCtx.destIP, "dest_port", sess.connCtx.destPort, "err", err)
}
return
}
}
// Write to listen socket and return buffer. WriteToUDP is safe to
// call concurrently on the same *net.UDPConn; each goroutine resets
// the write deadline immediately before its own write, so concurrent
// sessions may shift each other's deadline by at most WriteTimeout.
_ = p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
_, writeErr := p.conn.WriteToUDP((*bp)[:n], clientAddr)
udpBufPool.Put(bp)
if writeErr != nil {
select {
case <-ctx.Done():
return
default:
if !isClosedConnErr(writeErr) {
p.log.Warn(p.logPrefix+"write to client failed", "client", clientAddr, "err", writeErr)
}
return
}
}
sess.connCtx.touch()
sess.connCtx.PacketsIn.Add(1)
sess.connCtx.BytesIn.Add(uint64(n))
}
}
// removeSession evicts sess from the NAT table and the session cache.
func (p *UDPProxy) removeSession(sess *udpSession, key string) {
sess.remote.Close()
sess.connCtx.Close()
p.mu.Lock()
delete(p.sessions, key)
p.mu.Unlock()
p.cache.remove(sess.connCtx)
p.log.Debug(p.logPrefix+"session removed", "session", sess.connCtx.id, "dest_ip", sess.connCtx.destIP, "dest_port", sess.connCtx.destPort, "bytes_in", sess.connCtx.BytesIn.Load(), "bytes_out", sess.connCtx.BytesOut.Load())
}
// ─── Idle cleanup loop ────────────────────────────────────────────────────────
// cleanupLoop periodically inspects the NAT table and cancels sessions whose
// last-seen time predates the idle timeout.
func (p *UDPProxy) cleanupLoop() {
defer p.wg.Done()
interval := p.cfg.ReadTimeout / 2
if interval < time.Second*10 {
interval = time.Second * 10
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-p.shutdownCtx.Done():
// Cancel all sessions and close their remote sockets so that
// forwardFromRemote goroutines unblock from Read immediately.
p.mu.Lock()
for _, sess := range p.sessions {
sess.connCtx.Close()
sess.remote.Close() // unblocks pending Read in forwardFromRemote
}
p.mu.Unlock()
return
case <-ticker.C:
p.evictIdle()
}
}
}
// evictIdle cancels sessions that have been idle longer than ReadTimeout.
// It closes the remote socket so forwardFromRemote's Read unblocks faster
// than waiting for the next rolling deadline.
func (p *UDPProxy) evictIdle() {
threshold := time.Now().Add(-p.cfg.ReadTimeout)
p.mu.RLock()
defer p.mu.RUnlock()
for key, sess := range p.sessions {
if sess.connCtx.LastSeen().Before(threshold) {
p.log.Debug(p.logPrefix+"evicting idle session", "session", sess.connCtx.ID(), "client", key)
sess.connCtx.Close()
// Wake the blocked Read so the goroutine notices ctx.Done().
_ = sess.remote.SetReadDeadline(time.Now())
}
}
}
// ─── Error helpers ────────────────────────────────────────────────────────────
// isTimeoutErr returns true if err is a network timeout error.
func isTimeoutErr(err error) bool {
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}
+188
View File
@@ -0,0 +1,188 @@
package splittun
import (
"fmt"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/safing/portmaster/service/mgr"
"github.com/safing/portmaster/service/netenv"
"github.com/safing/portmaster/service/network"
"github.com/safing/portmaster/service/network/packet"
"github.com/safing/portmaster/service/splittun/proxy"
)
// pendingRequestTTL is the maximum time a pending request waits for the proxy
// to accept the redirected connection. If the OS drops/resets the connection
// before it reaches the proxy, the entry would otherwise leak indefinitely.
const pendingRequestTTL = 30 * time.Second
type request struct {
connInfo *network.Connection
binding proxy.LocalBinding
expiresAt time.Time
}
var (
requestsLock sync.Mutex
pendingRequests map[string]*request = make(map[string]*request) // key: "localIP:localPort"
cleanupScheduled atomic.Bool
)
// AwaitRequest registers a connection for handling when it arrives at the proxy.
// The bindInterface must be unique info which identifies the interface to bind to:
// - interface local IP address (e.g. "192.168.1.1")
// - interface name (e.g. "eth0")
// - MAC address (e.g. "00:1A:2B:3C:4D:5E")
// - empty - to try detecting "default" (non-VPN) interface automatically (not reliable)
func AwaitRequest(connInfo *network.Connection, bindInterface string) (*network.SplitTunContext, error) {
var binding proxy.LocalBinding
if bindInterface == "" {
// empty - is the default and means to try detecting the "default" (non-VPN) interface automatically.
// This is not reliable, but can be convenient for users who don't want to configure an interface.
ifaces, err := netenv.GetBestPhysicalDefaultInterfaces()
if err != nil {
return nil, err
}
var selectedIface *netenv.InterfaceInfo
if connInfo.IPVersion == packet.IPv6 && ifaces.ForIPv6 != nil {
selectedIface = ifaces.ForIPv6
binding.IP = selectedIface.IPv6
} else if connInfo.IPVersion == packet.IPv4 && ifaces.ForIPv4 != nil {
selectedIface = ifaces.ForIPv4
binding.IP = selectedIface.IPv4
} else {
return nil, fmt.Errorf("no suitable default physical interface found for %s", connInfo.IPVersion)
}
binding.Interface = selectedIface.Interface.Name
} else {
// Getting the interface IP address to bind the proxy connection to.
iface, err := netenv.GetInterface(bindInterface)
if err != nil {
return nil, err
}
if connInfo.IPVersion == packet.IPv6 {
binding.IP = iface.IPv6
} else {
binding.IP = iface.IPv4
}
if binding.IP == nil {
return nil, fmt.Errorf("interface %q has no usable address for %s", bindInterface, connInfo.IPVersion)
}
binding.Interface = iface.Interface.Name
}
// Create unique key for the pending connection
if connInfo.LocalIP == nil {
return nil, fmt.Errorf("connection has no local IP")
}
key := net.JoinHostPort(connInfo.LocalIP.String(), strconv.Itoa(int(connInfo.LocalPort)))
requestsLock.Lock()
defer requestsLock.Unlock()
// Register the request
if _, exists := pendingRequests[key]; exists {
return nil, fmt.Errorf("a pending request for %s already exists", key)
}
pendingRequests[key] = &request{
connInfo: connInfo,
binding: binding,
expiresAt: time.Now().Add(pendingRequestTTL),
}
// Schedule deferred cleanup outside of the hot path.
// The goroutine only starts if none is already running.
scheduleCleanup()
return &network.SplitTunContext{
Interface: binding.Interface,
IP: binding.IP,
}, nil
}
// scheduleCleanup starts a deferred cleanup goroutine if one is not already
// running. The goroutine wakes after pendingRequestTTL+1s, sweeps expired
// entries, and reschedules itself if unexpired entries remain. It exits
// immediately when the module's manager context is cancelled.
func scheduleCleanup() {
if !cleanupScheduled.CompareAndSwap(false, true) {
return // already scheduled; it will sweep our entry too
}
module.mgr.Go("pending-requests-cleanup", func(w *mgr.WorkerCtx) error {
select {
case <-w.Done():
cleanupScheduled.Store(false)
return nil
case <-time.After(pendingRequestTTL + time.Second):
}
requestsLock.Lock()
sweepPendingRequestsLocked()
nonEmpty := len(pendingRequests) > 0
requestsLock.Unlock()
// Reset flag before potential reschedule to avoid a gap where
// a concurrent AwaitRequest could miss starting a new goroutine.
cleanupScheduled.Store(false)
if nonEmpty {
scheduleCleanup()
}
return nil
})
}
// sweepPendingRequestsLocked removes any pending requests that have exceeded
// the TTL. The caller must hold requestsLock.
func sweepPendingRequestsLocked() {
now := time.Now()
for key, r := range pendingRequests {
if now.After(r.expiresAt) {
delete(pendingRequests, key)
}
}
}
// clearPendingRequests removes all pending requests. Called on module stop.
func clearPendingRequests() {
requestsLock.Lock()
pendingRequests = make(map[string]*request)
requestsLock.Unlock()
}
// consumeRequest retrieves and removes a pending request for the given address.
// Returns an error if the request has expired.
func consumeRequest(address string) (r *request, err error) {
requestsLock.Lock()
r, ok := pendingRequests[address]
if ok {
delete(pendingRequests, address)
requestsLock.Unlock()
if time.Now().After(r.expiresAt) {
return nil, fmt.Errorf("pending request for %s has expired", address)
}
return r, nil
}
requestsLock.Unlock()
return nil, fmt.Errorf("no pending request for %s", address)
}
// proxyDecider is called by the proxy for each new connection to determine the
// upstream destination and local binding parameters.
func proxyDecider(local net.Addr, peer net.Addr) (remoteIP net.IP, remotePort uint16, binding *proxy.LocalBinding, extraInfo any, err error) {
r, err := consumeRequest(peer.String())
if err != nil {
return nil, 0, nil, nil, err
}
return r.connInfo.Entity.IP, uint16(r.connInfo.Entity.Port), &r.binding, r.connInfo, nil
}
+12 -8
View File
@@ -92,15 +92,19 @@ var (
`,
},
{
Name: "Safing Support",
ID: string(account.FeatureSafingSupport),
RequiredFeatureID: account.FeatureSafingSupport,
InPackage: packagePlus,
Name: "Split Tunneling",
ID: "splittun",
ConfigKey: "splittun/enable",
ConfigScope: "splittun/",
InPackage: packageFree,
icon: `
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round"
d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" />
</svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" stroke="currentColor" fill="none">
<g stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5">
<path d="M8.5 12h3" />
<path d="M11 12c4-2 6-4 6.5-4.5" /> <path d="M11 12c4 2 6 4 6.5 4.5" />
<circle cx="19" cy="6" r="2" /> <circle cx="19" cy="18" r="2" /> <circle cx="4" cy="12" r="2" />
</g>
</svg>
`,
},
{
+4 -13
View File
@@ -7,7 +7,7 @@ use wdk::filter_engine::packet::InjectInfo;
use crate::connection::{
Connection, ConnectionV4, ConnectionV6, Direction, RedirectInfo, Verdict, PM_DNS_PORT,
PM_SPN_PORT,
PM_SPN_PORT, PM_SPLIT_TUN_PORT,
};
use crate::connection_cache::ConnectionCache;
use crate::connection_map::Key;
@@ -88,18 +88,9 @@ impl ConnectionInfo {
}
}
fn fast_track_pm_packets(key: &Key, direction: Direction) -> bool {
match direction {
Direction::Outbound => {
if key.local_port == PM_DNS_PORT || key.local_port == PM_SPN_PORT {
return key.local_address == key.remote_address;
}
}
Direction::Inbound => {
if key.local_port == PM_DNS_PORT || key.local_port == PM_SPN_PORT {
return key.local_address == key.remote_address;
}
}
fn fast_track_pm_packets(key: &Key, _: Direction) -> bool {
if key.local_port == PM_DNS_PORT || key.local_port == PM_SPN_PORT || key.local_port == PM_SPLIT_TUN_PORT {
return key.local_address == key.remote_address;
}
return false;
+1
View File
@@ -36,6 +36,7 @@ const (
VerdictRerouteToNameserver KextVerdict = 8
VerdictRerouteToTunnel KextVerdict = 9
VerdictFailed KextVerdict = 10
VerdictRerouteToSplitTun KextVerdict = 11
)
type Verdict struct {
+1 -1
View File
@@ -1 +1 @@
[2, 1, 0, 0]
[2, 1, 1, 0]