From e79f31db71aa05b8713b9c73f4f8c9b2374862f2 Mon Sep 17 00:00:00 2001 From: Jianglong Date: Thu, 14 Sep 2023 02:35:16 +0800 Subject: [PATCH] feat: add up and down speed limits (#54) Co-authored-by: Jianglong --- conn.go | 15 +++++++++++++-- go.mod | 1 + go.sum | 2 ++ settings.go | 6 +++++- tproxy.go | 4 +++- 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 692c882..d94c894 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "time" "github.com/fatih/color" + "github.com/juju/ratelimit" "github.com/kevwan/tproxy/display" "github.com/kevwan/tproxy/protocol" ) @@ -58,7 +59,12 @@ func (c *PairedConnection) handleClientMessage() { r, w := io.Pipe() tee := io.MultiWriter(c.svrConn, w) go protocol.CreateInterop(settings.Protocol).Dump(r, protocol.ClientSide, c.id, settings.Quiet) - c.copyData(tee, c.cliConn, protocol.ClientSide) + var src io.Reader = c.cliConn + if settings.UpLimit > 0 { + bucket := ratelimit.NewBucket(time.Second, settings.UpLimit) + src = ratelimit.Reader(src, bucket) + } + c.copyData(tee, src, protocol.ClientSide) } func (c *PairedConnection) handleServerMessage() { @@ -68,7 +74,12 @@ func (c *PairedConnection) handleServerMessage() { r, w := io.Pipe() tee := io.MultiWriter(newDelayedWriter(c.cliConn, settings.Delay, c.stopChan), w) go protocol.CreateInterop(settings.Protocol).Dump(r, protocol.ServerSide, c.id, settings.Quiet) - c.copyData(tee, c.svrConn, protocol.ServerSide) + var src io.Reader = c.svrConn + if settings.DownLimit > 0 { + bucket := ratelimit.NewBucket(time.Second, settings.DownLimit) + src = ratelimit.Reader(src, bucket) + } + c.copyData(tee, src, protocol.ServerSide) } func (c *PairedConnection) process() { diff --git a/go.mod b/go.mod index 9511bd7..46983cb 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( ) require ( + github.com/juju/ratelimit v1.0.2 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect diff --git a/go.sum b/go.sum index 6ae5197..aed1cac 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI= +github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= diff --git a/settings.go b/settings.go index 34e273f..e35e3b3 100644 --- a/settings.go +++ b/settings.go @@ -10,10 +10,12 @@ type Settings struct { Protocol string Stat bool Quiet bool + UpLimit int64 + DownLimit int64 } func saveSettings(localHost string, localPort int, remote string, delay time.Duration, - protocol string, stat, quiet bool) { + protocol string, stat, quiet bool, upLimit, downLimit int64) { if localHost != "" { settings.LocalHost = localHost } @@ -27,4 +29,6 @@ func saveSettings(localHost string, localPort int, remote string, delay time.Dur settings.Protocol = protocol settings.Stat = stat settings.Quiet = quiet + settings.UpLimit = upLimit + settings.DownLimit = downLimit } diff --git a/tproxy.go b/tproxy.go index f994190..f2d0188 100644 --- a/tproxy.go +++ b/tproxy.go @@ -20,6 +20,8 @@ func main() { stat = flag.Bool("s", false, "Enable statistics") quiet = flag.Bool("q", false, "Quiet mode, only prints connection open/close and stats, default false") + upLimit = flag.Int64("U", 0, "Upward speed limit(Bytes/second)") + downLimit = flag.Int64("D", 0, "Downward speed limit(Bytes/second)") ) if len(os.Args) <= 1 { @@ -28,7 +30,7 @@ func main() { } flag.Parse() - saveSettings(*localHost, *localPort, *remote, *delay, *protocol, *stat, *quiet) + saveSettings(*localHost, *localPort, *remote, *delay, *protocol, *stat, *quiet, *upLimit, *downLimit) if len(settings.Remote) == 0 { fmt.Fprintln(os.Stderr, color.HiRedString("[x] Remote target required"))