Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0a4c2ea815 | |||
| f933e8019e | |||
| a198407ec3 | |||
| d70ad6c956 | |||
| ede08cf685 | |||
| 2cc71fb749 | |||
| 6c74582ea9 | |||
| 6644dc814f | |||
| a705995975 | |||
| 11fecf34d2 | |||
| 1429635b95 | |||
| 065fae8ee7 |
@@ -6,7 +6,6 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/anacrolix/missinggo"
|
||||
@@ -24,8 +23,8 @@ type Conn struct {
|
||||
connKey connKey
|
||||
|
||||
// Data waiting to be Read.
|
||||
readBuf []byte
|
||||
readCond sync.Cond
|
||||
readBuf []byte
|
||||
readBufNotEmpty missinggo.Event
|
||||
|
||||
socket *Socket
|
||||
remoteSocketAddr net.Addr
|
||||
@@ -36,12 +35,13 @@ type Conn struct {
|
||||
|
||||
sentSyn bool
|
||||
synAcked bool
|
||||
gotFin missinggo.Flag
|
||||
wroteFin bool
|
||||
gotFin missinggo.Event
|
||||
wroteFin missinggo.Event
|
||||
finAcked bool
|
||||
err error
|
||||
closed missinggo.Flag
|
||||
destroyed missinggo.Flag
|
||||
closed missinggo.Event
|
||||
destroyed missinggo.Event
|
||||
canWrite missinggo.Event
|
||||
|
||||
unackedSends []*send
|
||||
// Inbound payloads, the first is ack_nr+1.
|
||||
@@ -83,7 +83,7 @@ func (c *Conn) sendPendingSendStateTimerCallback() {
|
||||
|
||||
// Send a state packet, if one is needed.
|
||||
func (c *Conn) sendPendingState() {
|
||||
if c.destroyed.Get() {
|
||||
if c.destroyed.IsSet() {
|
||||
c.sendReset()
|
||||
} else {
|
||||
c.sendState()
|
||||
@@ -196,7 +196,7 @@ func (c *Conn) write(_type st, connID uint16, payload []byte, seqNr uint16) (n i
|
||||
default:
|
||||
panic(_type)
|
||||
}
|
||||
if c.wroteFin {
|
||||
if c.wroteFin.IsSet() {
|
||||
panic("can't write after fin")
|
||||
}
|
||||
if len(payload) > maxPayloadSize {
|
||||
@@ -224,6 +224,7 @@ func (c *Conn) write(_type st, connID uint16, payload []byte, seqNr uint16) (n i
|
||||
send.resendTimer = time.AfterFunc(c.resendTimeout(), send.timeoutResend)
|
||||
c.unackedSends = append(c.unackedSends, send)
|
||||
c.cur_window += send.payloadSize
|
||||
c.updateCanWrite()
|
||||
c.seq_nr++
|
||||
return
|
||||
}
|
||||
@@ -242,7 +243,7 @@ func (c *Conn) latency() (ret time.Duration) {
|
||||
|
||||
func (c *Conn) numUnackedSends() (num int) {
|
||||
for _, s := range c.unackedSends {
|
||||
if !s.acked {
|
||||
if !s.acked.IsSet() {
|
||||
num++
|
||||
}
|
||||
}
|
||||
@@ -273,6 +274,7 @@ func (c *Conn) ack(nr uint16) {
|
||||
latency, first := s.Ack()
|
||||
if first {
|
||||
c.cur_window -= s.payloadSize
|
||||
c.updateCanWrite()
|
||||
c.latencies = append(c.latencies, latency)
|
||||
if len(c.latencies) > 10 {
|
||||
c.latencies = c.latencies[len(c.latencies)-10:]
|
||||
@@ -280,15 +282,15 @@ func (c *Conn) ack(nr uint16) {
|
||||
}
|
||||
// Trim sends that aren't needed anymore.
|
||||
for len(c.unackedSends) != 0 {
|
||||
if !c.unackedSends[0].acked {
|
||||
if !c.unackedSends[0].acked.IsSet() {
|
||||
// Can't trim unacked sends any further.
|
||||
return
|
||||
}
|
||||
// Trim the front of the unacked sends.
|
||||
c.unackedSends = c.unackedSends[1:]
|
||||
c.updateCanWrite()
|
||||
c.lastAck++
|
||||
}
|
||||
cond.Broadcast()
|
||||
}
|
||||
|
||||
func (c *Conn) ackTo(nr uint16) {
|
||||
@@ -327,7 +329,7 @@ func (c *Conn) ackSkipped(seqNr uint16) {
|
||||
return
|
||||
}
|
||||
send.acksSkipped++
|
||||
if send.acked {
|
||||
if send.acked.IsSet() {
|
||||
return
|
||||
}
|
||||
switch send.acksSkipped {
|
||||
@@ -352,7 +354,7 @@ func (c *Conn) receivePacketTimeoutCallback() {
|
||||
}
|
||||
|
||||
func (c *Conn) lazyDestroy() {
|
||||
if c.wroteFin && len(c.unackedSends) <= 1 && (c.gotFin.Get() || c.closed.Get()) {
|
||||
if c.wroteFin.IsSet() && len(c.unackedSends) <= 1 && (c.gotFin.IsSet() || c.closed.IsSet()) {
|
||||
c.destroy(errors.New("lazily destroyed"))
|
||||
}
|
||||
}
|
||||
@@ -360,7 +362,6 @@ func (c *Conn) lazyDestroy() {
|
||||
func (c *Conn) processDelivery(h header, payload []byte) {
|
||||
deliveriesProcessed.Add(1)
|
||||
defer c.lazyDestroy()
|
||||
defer cond.Broadcast()
|
||||
c.assertHeader(h)
|
||||
c.peerWndSize = h.WndSize
|
||||
c.applyAcks(h)
|
||||
@@ -379,6 +380,7 @@ func (c *Conn) processDelivery(h header, payload []byte) {
|
||||
return
|
||||
}
|
||||
c.synAcked = true
|
||||
c.updateCanWrite()
|
||||
c.ack_nr = h.SeqNr - 1
|
||||
return
|
||||
}
|
||||
@@ -448,17 +450,21 @@ func (c *Conn) assertHeader(h header) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) updateReadBufNotEmpty() {
|
||||
c.readBufNotEmpty.SetBool(len(c.readBuf) != 0)
|
||||
}
|
||||
|
||||
func (c *Conn) processInbound() {
|
||||
// Consume consecutive next packets.
|
||||
for !c.gotFin.Get() && len(c.inbound) > 0 && c.inbound[0].seen && len(c.readBuf) < readBufferLen {
|
||||
for !c.gotFin.IsSet() && len(c.inbound) > 0 && c.inbound[0].seen && len(c.readBuf) < readBufferLen {
|
||||
c.ack_nr++
|
||||
p := c.inbound[0]
|
||||
c.inbound = c.inbound[1:]
|
||||
c.inboundWnd -= uint32(len(p.data))
|
||||
c.readBuf = append(c.readBuf, p.data...)
|
||||
c.readCond.Broadcast()
|
||||
c.updateReadBufNotEmpty()
|
||||
if p.Type == stFin {
|
||||
c.gotFin.Set(true)
|
||||
c.gotFin.Set()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -468,9 +474,7 @@ func (c *Conn) waitAck(seq uint16) {
|
||||
if send == nil {
|
||||
return
|
||||
}
|
||||
for !(send.acked || c.destroyed.Get()) {
|
||||
cond.Wait()
|
||||
}
|
||||
missinggo.WaitEvents(&mu, &send.acked, &c.destroyed)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -489,23 +493,21 @@ func (c *Conn) connect() (err error) {
|
||||
err = c.err
|
||||
}
|
||||
c.synAcked = true
|
||||
cond.Broadcast()
|
||||
c.updateCanWrite()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) writeFin() {
|
||||
if c.wroteFin {
|
||||
if c.wroteFin.IsSet() {
|
||||
return
|
||||
}
|
||||
c.write(stFin, c.send_id, nil, c.seq_nr)
|
||||
c.wroteFin = true
|
||||
cond.Broadcast()
|
||||
c.wroteFin.Set()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) destroy(reason error) {
|
||||
c.destroyed.Set(true)
|
||||
cond.Broadcast()
|
||||
c.destroyed.Set()
|
||||
if c.err == nil {
|
||||
c.err = reason
|
||||
}
|
||||
@@ -515,8 +517,7 @@ func (c *Conn) destroy(reason error) {
|
||||
func (c *Conn) Close() (err error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
c.closed.Set(true)
|
||||
cond.Broadcast()
|
||||
c.closed.Set()
|
||||
c.writeFin()
|
||||
c.lazyDestroy()
|
||||
return
|
||||
@@ -532,27 +533,33 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
for {
|
||||
n = copy(b, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
c.updateReadBufNotEmpty()
|
||||
if n != 0 {
|
||||
// Inbound packets are backed up when the read buffer is too big.
|
||||
c.processInbound()
|
||||
return
|
||||
}
|
||||
if c.gotFin.Get() || c.closed.Get() {
|
||||
if c.gotFin.IsSet() || c.closed.IsSet() {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
if c.destroyed.Get() {
|
||||
if c.destroyed.IsSet() {
|
||||
if c.err == nil {
|
||||
panic("closed without receiving fin, and no error")
|
||||
}
|
||||
err = c.err
|
||||
return
|
||||
}
|
||||
if c.connDeadlines.read.passed.Get() {
|
||||
if c.connDeadlines.read.passed.IsSet() {
|
||||
err = errTimeout
|
||||
return
|
||||
}
|
||||
c.readCond.Wait()
|
||||
missinggo.WaitEvents(&mu,
|
||||
&c.gotFin,
|
||||
&c.closed,
|
||||
&c.destroyed,
|
||||
&c.connDeadlines.read.passed,
|
||||
&c.readBufNotEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -564,27 +571,31 @@ func (c *Conn) String() string {
|
||||
return fmt.Sprintf("<UTPConn %s-%s (%d)>", c.LocalAddr(), c.RemoteAddr(), c.recv_id)
|
||||
}
|
||||
|
||||
func (c *Conn) updateCanWrite() {
|
||||
c.canWrite.SetBool(c.synAcked &&
|
||||
len(c.unackedSends) < maxUnackedSends &&
|
||||
c.cur_window <= c.peerWndSize)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for len(p) != 0 {
|
||||
if c.wroteFin || c.closed.Get() {
|
||||
if c.wroteFin.IsSet() || c.closed.IsSet() {
|
||||
err = errClosed
|
||||
return
|
||||
}
|
||||
if c.destroyed.Get() {
|
||||
if c.destroyed.IsSet() {
|
||||
err = c.err
|
||||
return
|
||||
}
|
||||
if c.connDeadlines.write.passed.Get() {
|
||||
if c.connDeadlines.write.passed.IsSet() {
|
||||
err = errTimeout
|
||||
return
|
||||
}
|
||||
// If peerWndSize is 0, we still want to send something, so don't
|
||||
// block until we exceed it.
|
||||
if c.synAcked &&
|
||||
len(c.unackedSends) < maxUnackedSends &&
|
||||
c.cur_window <= c.peerWndSize {
|
||||
if c.canWrite.IsSet() {
|
||||
var n1 int
|
||||
n1, err = c.write(stData, c.send_id, p, c.seq_nr)
|
||||
n += n1
|
||||
@@ -597,7 +608,12 @@ func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
p = p[n1:]
|
||||
continue
|
||||
}
|
||||
cond.Wait()
|
||||
missinggo.WaitEvents(&mu,
|
||||
&c.wroteFin,
|
||||
&c.closed,
|
||||
&c.destroyed,
|
||||
&c.connDeadlines.write.passed,
|
||||
&c.canWrite)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
+24
-14
@@ -8,32 +8,42 @@ import (
|
||||
|
||||
type deadline struct {
|
||||
t time.Time
|
||||
passed missinggo.Flag
|
||||
passed missinggo.Event
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func (me *deadline) set(t time.Time) {
|
||||
me.passed.Set(false)
|
||||
me.t = t
|
||||
me.timer = time.AfterFunc(0, me.callback)
|
||||
me.passed.Clear()
|
||||
if me.timer != nil {
|
||||
me.timer.Stop()
|
||||
}
|
||||
me.update()
|
||||
}
|
||||
|
||||
func (me *deadline) update() {
|
||||
if me.t.IsZero() {
|
||||
return
|
||||
}
|
||||
if time.Now().Before(me.t) {
|
||||
if me.timer == nil {
|
||||
me.timer = time.AfterFunc(me.t.Sub(time.Now()), me.callback)
|
||||
} else {
|
||||
me.timer.Reset(me.t.Sub(time.Now()))
|
||||
}
|
||||
return
|
||||
}
|
||||
me.passed.Set()
|
||||
}
|
||||
|
||||
func (me *deadline) callback() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if me.t.IsZero() {
|
||||
return
|
||||
}
|
||||
if time.Now().Before(me.t) {
|
||||
me.timer.Reset(me.t.Sub(time.Now()))
|
||||
return
|
||||
}
|
||||
me.passed.Set(true)
|
||||
cond.Broadcast()
|
||||
me.update()
|
||||
}
|
||||
|
||||
// This is embedded in Conn to provide deadline methods for net.Conn. It
|
||||
// tickles global mu and cond as required.
|
||||
// This is embedded in Conn and Socket to provide deadline methods for
|
||||
// net.Conn.
|
||||
type connDeadlines struct {
|
||||
read, write deadline
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
type send struct {
|
||||
acked bool // Closed with Conn lock.
|
||||
acked missinggo.Event
|
||||
payloadSize uint32
|
||||
started missinggo.MonotonicTime
|
||||
_type st
|
||||
@@ -25,12 +25,11 @@ type send struct {
|
||||
// first is true if this is the first time the send is acked. latency is
|
||||
// calculated for the first ack.
|
||||
func (s *send) Ack() (latency time.Duration, first bool) {
|
||||
first = !s.acked
|
||||
first = !s.acked.IsSet()
|
||||
if first {
|
||||
latency = missinggo.MonotonicSince(s.started)
|
||||
}
|
||||
s.acked = true
|
||||
cond.Broadcast()
|
||||
s.acked.Set()
|
||||
if s.resendTimer != nil {
|
||||
s.resendTimer.Stop()
|
||||
s.resendTimer = nil
|
||||
@@ -49,7 +48,7 @@ func (s *send) timeoutResend() {
|
||||
s.timedOut()
|
||||
return
|
||||
}
|
||||
if s.acked || s.conn.destroyed.Get() {
|
||||
if s.acked.IsSet() || s.conn.destroyed.IsSet() {
|
||||
return
|
||||
}
|
||||
rt := s.conn.resendTimeout()
|
||||
@@ -59,7 +58,7 @@ func (s *send) timeoutResend() {
|
||||
}
|
||||
|
||||
func (s *send) resend() {
|
||||
if s.acked {
|
||||
if s.acked.IsSet() {
|
||||
return
|
||||
}
|
||||
err := s.conn.send(s._type, s.connID, s.payload, s.seqNr)
|
||||
|
||||
@@ -274,15 +274,8 @@ func (s *Socket) newConn(addr net.Addr) (c *Conn) {
|
||||
remoteSocketAddr: addr,
|
||||
created: time.Now(),
|
||||
}
|
||||
c.readCond.L = &mu
|
||||
c.sendPendingSendSendStateTimer = missinggo.StoppedFuncTimer(c.sendPendingSendStateTimerCallback)
|
||||
c.packetReadTimeoutTimer = time.AfterFunc(packetReadTimeout, c.receivePacketTimeoutCallback)
|
||||
missinggo.AddCondToFlags(
|
||||
&c.readCond,
|
||||
&c.destroyed,
|
||||
&c.gotFin,
|
||||
&c.closed,
|
||||
&c.connDeadlines.read.passed)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -349,6 +342,9 @@ func (s *Socket) DialTimeout(addr string, timeout time.Duration) (nc net.Conn, e
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
c.updateCanWrite()
|
||||
mu.Unlock()
|
||||
nc = pproffd.WrapNetConn(c)
|
||||
return
|
||||
}
|
||||
@@ -389,9 +385,7 @@ func (s *Socket) backlogChanged() {
|
||||
|
||||
func (s *Socket) nextSyn() (syn syn, err error) {
|
||||
for {
|
||||
mu.Unlock()
|
||||
missinggo.WaitEvents(&mu, &s.closed, &s.backlogNotEmpty, &s.destroyed)
|
||||
mu.Lock()
|
||||
if s.closed.IsSet() {
|
||||
err = errClosed
|
||||
return
|
||||
@@ -420,6 +414,7 @@ func (s *Socket) ackSyn(syn syn) (c *Conn, ok bool) {
|
||||
c.ack_nr = syn.seq_nr
|
||||
c.sentSyn = true
|
||||
c.synAcked = true
|
||||
c.updateCanWrite()
|
||||
if !s.registerConn(c.recv_id, resolvedAddrStr(syn.addr), c) {
|
||||
// SYN that triggered this accept duplicates existing connection.
|
||||
// Ack again in case the SYN was a resend.
|
||||
@@ -446,6 +441,7 @@ func (s *Socket) Accept() (net.Conn, error) {
|
||||
}
|
||||
c, ok := s.ackSyn(syn)
|
||||
if ok {
|
||||
c.updateCanWrite()
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
@@ -490,15 +486,29 @@ func (s *Socket) LocalAddr() net.Addr {
|
||||
}
|
||||
|
||||
func (s *Socket) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
read, ok := <-s.unusedReads
|
||||
if !ok {
|
||||
err = io.EOF
|
||||
select {
|
||||
case read, ok := <-s.unusedReads:
|
||||
if !ok {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
n = copy(p, read.data)
|
||||
addr = read.from
|
||||
return
|
||||
case <-s.connDeadlines.read.passed.LockedChan(&mu):
|
||||
err = errTimeout
|
||||
return
|
||||
}
|
||||
n = copy(p, read.data)
|
||||
addr = read.from
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Socket) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
func (s *Socket) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
mu.Lock()
|
||||
if s.connDeadlines.write.passed.IsSet() {
|
||||
err = errTimeout
|
||||
}
|
||||
mu.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.pc.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
+153
@@ -1,9 +1,16 @@
|
||||
package utp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/anacrolix/missinggo"
|
||||
"github.com/anacrolix/missinggo/inproc"
|
||||
"github.com/bradfitz/iter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -17,3 +24,149 @@ func TestAcceptOnDestroyedSocket(t *testing.T) {
|
||||
_, err = s.Accept()
|
||||
assert.Contains(t, err.Error(), "use of closed network connection")
|
||||
}
|
||||
|
||||
func TestSocketDeadlines(t *testing.T) {
|
||||
s, err := NewSocket("udp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
assert.NoError(t, s.SetReadDeadline(time.Now()))
|
||||
_, _, err = s.ReadFrom(nil)
|
||||
assert.Equal(t, errTimeout, err)
|
||||
assert.NoError(t, s.SetWriteDeadline(time.Now()))
|
||||
_, err = s.WriteTo(nil, nil)
|
||||
assert.Equal(t, errTimeout, err)
|
||||
assert.NoError(t, s.SetDeadline(time.Time{}))
|
||||
assert.NoError(t, s.Close())
|
||||
}
|
||||
|
||||
func TestSaturateSocketConnIDs(t *testing.T) {
|
||||
s, err := NewSocket("inproc", "")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
var acceptedConns, dialedConns []net.Conn
|
||||
for range iter.N(500) {
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
c, err := s.Accept()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
acceptedConns = append(acceptedConns, c)
|
||||
close(accepted)
|
||||
}()
|
||||
c, err := s.Dial(s.Addr().String())
|
||||
require.NoError(t, err)
|
||||
dialedConns = append(dialedConns, c)
|
||||
<-accepted
|
||||
}
|
||||
t.Logf("%d dialed conns, %d accepted", len(dialedConns), len(acceptedConns))
|
||||
for i := range iter.N(len(dialedConns)) {
|
||||
data := []byte(fmt.Sprintf("%7d", i))
|
||||
dc := dialedConns[i]
|
||||
n, err := dc.Write(data)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 7, n)
|
||||
require.NoError(t, dc.Close())
|
||||
var b [8]byte
|
||||
ac := acceptedConns[i]
|
||||
n, err = ac.Read(b[:])
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 7, n)
|
||||
require.EqualValues(t, data, b[:n])
|
||||
n, err = ac.Read(b[:])
|
||||
require.EqualValues(t, 0, n)
|
||||
require.EqualValues(t, io.EOF, err)
|
||||
ac.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestUTPRawConn(t *testing.T) {
|
||||
l, err := NewSocket("inproc", "")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
_, err := l.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Connect a UTP peer to see if the RawConn will still work.
|
||||
log.Print("dialing")
|
||||
utpPeer := func() net.Conn {
|
||||
s, _ := NewSocket("inproc", "")
|
||||
defer s.Close()
|
||||
ret, err := s.Dial(fmt.Sprintf("localhost:%d", missinggo.AddrPort(l.Addr())))
|
||||
require.NoError(t, err)
|
||||
return ret
|
||||
}()
|
||||
log.Print("dial returned")
|
||||
if err != nil {
|
||||
t.Fatalf("error dialing utp listener: %s", err)
|
||||
}
|
||||
defer utpPeer.Close()
|
||||
peer, err := inproc.ListenPacket("inproc", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer peer.Close()
|
||||
|
||||
msgsReceived := 0
|
||||
const N = 500 // How many messages to send.
|
||||
readerStopped := make(chan struct{})
|
||||
// The reader goroutine.
|
||||
go func() {
|
||||
defer close(readerStopped)
|
||||
b := make([]byte, 500)
|
||||
for i := 0; i < N; i++ {
|
||||
n, _, err := l.ReadFrom(b)
|
||||
if err != nil {
|
||||
t.Fatalf("error reading from raw conn: %s", err)
|
||||
}
|
||||
msgsReceived++
|
||||
var d int
|
||||
fmt.Sscan(string(b[:n]), &d)
|
||||
if d != i {
|
||||
log.Printf("got wrong number: expected %d, got %d", i, d)
|
||||
}
|
||||
}
|
||||
}()
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", missinggo.AddrPort(l.Addr())))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i := 0; i < N; i++ {
|
||||
_, err := peer.WriteTo([]byte(fmt.Sprintf("%d", i)), udpAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
}
|
||||
select {
|
||||
case <-readerStopped:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("reader timed out")
|
||||
}
|
||||
if msgsReceived != N {
|
||||
t.Fatalf("messages received: %d", msgsReceived)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptGone(t *testing.T) {
|
||||
s, err := NewSocket("udp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
_, err = DialTimeout(s.Addr().String(), time.Millisecond)
|
||||
require.Error(t, err)
|
||||
// Will succeed because we don't signal that we give up dialing, or check
|
||||
// that the handshake is completed before returning the new Conn.
|
||||
c, err := s.Accept()
|
||||
require.NoError(t, err)
|
||||
defer c.Close()
|
||||
err = c.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
require.NoError(t, err)
|
||||
_, err = c.Read(nil)
|
||||
require.EqualError(t, err, "i/o timeout")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
package utp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getTCPConnectionPair() (net.Conn, net.Conn, error) {
|
||||
lst, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var conn0 net.Conn
|
||||
var err0 error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
conn0, err0 = lst.Accept()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
conn1, err := net.Dial("tcp", lst.Addr().String())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
<-done
|
||||
if err0 != nil {
|
||||
return nil, nil, err0
|
||||
}
|
||||
return conn0, conn1, nil
|
||||
}
|
||||
|
||||
func getUTPConnectionPair() (net.Conn, net.Conn, error) {
|
||||
lst, err := NewSocket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer lst.Close()
|
||||
|
||||
var conn0 net.Conn
|
||||
var err0 error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
conn0, err0 = lst.Accept()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
conn1, err := Dial(lst.Addr().String())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
<-done
|
||||
if err0 != nil {
|
||||
return nil, nil, err0
|
||||
}
|
||||
|
||||
return conn0, conn1, nil
|
||||
}
|
||||
|
||||
func benchConnPair(b *testing.B, c0, c1 net.Conn) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(128 << 10)
|
||||
b.ResetTimer()
|
||||
|
||||
request := make([]byte, 52)
|
||||
response := make([]byte, (128<<10)+8)
|
||||
|
||||
pair := []net.Conn{c0, c1}
|
||||
for i := 0; i < b.N; i++ {
|
||||
if i%2 == 0 {
|
||||
pair[0] = c0
|
||||
pair[1] = c1
|
||||
} else {
|
||||
pair[0] = c1
|
||||
pair[1] = c0
|
||||
}
|
||||
|
||||
if _, err := pair[0].Write(request); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := pair[1].Read(request[:8]); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := pair[1].Read(request[8:]); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := pair[1].Write(response); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := pair[0].Read(response[:8]); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := pair[0].Read(response[8:]); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSyncthingTCP(b *testing.B) {
|
||||
conn0, conn1, err := getTCPConnectionPair()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
defer conn0.Close()
|
||||
defer conn1.Close()
|
||||
|
||||
benchConnPair(b, conn0, conn1)
|
||||
}
|
||||
|
||||
func BenchmarkSyncthingUDPUTP(b *testing.B) {
|
||||
conn0, conn1, err := getUTPConnectionPair()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
defer conn0.Close()
|
||||
defer conn1.Close()
|
||||
|
||||
benchConnPair(b, conn0, conn1)
|
||||
}
|
||||
|
||||
func BenchmarkSyncthingInprocUTP(b *testing.B) {
|
||||
c0, c1 := connPair()
|
||||
defer c0.Close()
|
||||
defer c1.Close()
|
||||
benchConnPair(b, c0, c1)
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pprofsync "github.com/anacrolix/sync"
|
||||
@@ -106,7 +105,6 @@ type syn struct {
|
||||
|
||||
var (
|
||||
mu pprofsync.RWMutex
|
||||
cond = sync.Cond{L: &mu}
|
||||
sockets = map[*Socket]struct{}{}
|
||||
logLevel = 0
|
||||
artificialPacketDropChance = 0.0
|
||||
|
||||
+9
-151
@@ -3,7 +3,6 @@ package utp
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
@@ -16,13 +15,15 @@ import (
|
||||
"time"
|
||||
|
||||
_ "github.com/anacrolix/envpprof"
|
||||
"github.com/anacrolix/missinggo"
|
||||
"github.com/anacrolix/missinggo/inproc"
|
||||
"github.com/bradfitz/iter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetFlags(log.Flags() | log.Lshortfile)
|
||||
}
|
||||
|
||||
func setDefaultTestingDurations() {
|
||||
writeTimeout = 1 * time.Second
|
||||
initialLatency = 10 * time.Millisecond
|
||||
@@ -101,79 +102,6 @@ func TestMinMaxHeaderType(t *testing.T) {
|
||||
require.Equal(t, stSyn, stMax)
|
||||
}
|
||||
|
||||
func TestUTPRawConn(t *testing.T) {
|
||||
l, err := NewSocket("inproc", "")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
go func() {
|
||||
for {
|
||||
_, err := l.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Connect a UTP peer to see if the RawConn will still work.
|
||||
log.Print("dialing")
|
||||
utpPeer := func() net.Conn {
|
||||
s, _ := NewSocket("inproc", "")
|
||||
defer s.Close()
|
||||
ret, err := s.Dial(fmt.Sprintf("localhost:%d", missinggo.AddrPort(l.Addr())))
|
||||
require.NoError(t, err)
|
||||
return ret
|
||||
}()
|
||||
log.Print("dial returned")
|
||||
if err != nil {
|
||||
t.Fatalf("error dialing utp listener: %s", err)
|
||||
}
|
||||
defer utpPeer.Close()
|
||||
peer, err := inproc.ListenPacket("inproc", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer peer.Close()
|
||||
|
||||
msgsReceived := 0
|
||||
const N = 500 // How many messages to send.
|
||||
readerStopped := make(chan struct{})
|
||||
// The reader goroutine.
|
||||
go func() {
|
||||
defer close(readerStopped)
|
||||
b := make([]byte, 500)
|
||||
for i := 0; i < N; i++ {
|
||||
n, _, err := l.ReadFrom(b)
|
||||
if err != nil {
|
||||
t.Fatalf("error reading from raw conn: %s", err)
|
||||
}
|
||||
msgsReceived++
|
||||
var d int
|
||||
fmt.Sscan(string(b[:n]), &d)
|
||||
if d != i {
|
||||
log.Printf("got wrong number: expected %d, got %d", i, d)
|
||||
}
|
||||
}
|
||||
}()
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", missinggo.AddrPort(l.Addr())))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i := 0; i < N; i++ {
|
||||
_, err := peer.WriteTo([]byte(fmt.Sprintf("%d", i)), udpAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
}
|
||||
select {
|
||||
case <-readerStopped:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("reader timed out")
|
||||
}
|
||||
if msgsReceived != N {
|
||||
t.Fatalf("messages received: %d", msgsReceived)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnReadDeadline(t *testing.T) {
|
||||
t.Parallel()
|
||||
ls, _ := NewSocket("udp", "localhost:0")
|
||||
@@ -193,8 +121,8 @@ func TestConnReadDeadline(t *testing.T) {
|
||||
_, err := c.Read(nil)
|
||||
require.Equal(t, errTimeout, err)
|
||||
// The deadline has passed.
|
||||
if !time.Now().After(dl) {
|
||||
t.FailNow()
|
||||
if time.Now().Before(dl) {
|
||||
t.Fatal("deadline hasn't passed")
|
||||
}
|
||||
// Returns timeout on subsequent read.
|
||||
_, err = c.Read(nil)
|
||||
@@ -209,7 +137,7 @@ func TestConnReadDeadline(t *testing.T) {
|
||||
select {
|
||||
case <-readReturned:
|
||||
// Read returned but shouldn't have.
|
||||
t.FailNow()
|
||||
t.Fatal("read returned")
|
||||
case <-time.After(time.Millisecond):
|
||||
}
|
||||
c.Close()
|
||||
@@ -451,23 +379,6 @@ func TestConnCloseUnclosedSocket(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptGone(t *testing.T) {
|
||||
s, err := NewSocket("udp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
_, err = DialTimeout(s.Addr().String(), time.Millisecond)
|
||||
require.Error(t, err)
|
||||
// Will succeed because we don't signal that we give up dialing, or check
|
||||
// that the handshake is completed before returning the new Conn.
|
||||
c, err := s.Accept()
|
||||
require.NoError(t, err)
|
||||
defer c.Close()
|
||||
err = c.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||
require.NoError(t, err)
|
||||
_, err = c.Read(nil)
|
||||
require.EqualError(t, err, "i/o timeout")
|
||||
}
|
||||
|
||||
func TestPacketReadTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
a, b := connPair()
|
||||
@@ -537,52 +448,6 @@ func TestAcceptReturnsAfterClose(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
func init() {
|
||||
log.SetFlags(log.Flags() | log.Lshortfile)
|
||||
}
|
||||
|
||||
func TestSaturateSocketConnIDs(t *testing.T) {
|
||||
s, err := NewSocket("inproc", "")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
var acceptedConns, dialedConns []net.Conn
|
||||
for range iter.N(500) {
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
c, err := s.Accept()
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
acceptedConns = append(acceptedConns, c)
|
||||
close(accepted)
|
||||
}()
|
||||
c, err := s.Dial(s.Addr().String())
|
||||
require.NoError(t, err)
|
||||
dialedConns = append(dialedConns, c)
|
||||
<-accepted
|
||||
}
|
||||
t.Logf("%d dialed conns, %d accepted", len(dialedConns), len(acceptedConns))
|
||||
for i := range iter.N(len(dialedConns)) {
|
||||
data := []byte(fmt.Sprintf("%7d", i))
|
||||
dc := dialedConns[i]
|
||||
n, err := dc.Write(data)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 7, n)
|
||||
require.NoError(t, dc.Close())
|
||||
var b [8]byte
|
||||
ac := acceptedConns[i]
|
||||
n, err = ac.Read(b[:])
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 7, n)
|
||||
require.EqualValues(t, data, b[:n])
|
||||
n, err = ac.Read(b[:])
|
||||
require.EqualValues(t, 0, n)
|
||||
require.EqualValues(t, io.EOF, err)
|
||||
ac.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteClose(t *testing.T) {
|
||||
a, b := connPair()
|
||||
defer a.Close()
|
||||
@@ -618,15 +483,6 @@ func TestWriteUnderlyingPacketConnClosed(t *testing.T) {
|
||||
assert.EqualError(t, err, "Socket destroyed")
|
||||
}
|
||||
|
||||
func TestSetSocketDeadlines(t *testing.T) {
|
||||
s, err := NewSocket("udp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, s.SetReadDeadline(time.Now().Add(time.Second)))
|
||||
assert.NoError(t, s.SetWriteDeadline(time.Now().Add(time.Second)))
|
||||
assert.NoError(t, s.SetDeadline(time.Time{}))
|
||||
assert.NoError(t, s.Close())
|
||||
}
|
||||
|
||||
func TestFillBuffers(t *testing.T) {
|
||||
a, b := connPair()
|
||||
defer b.Close()
|
||||
@@ -684,6 +540,8 @@ func BenchmarkEchoLongBuffer(tb *testing.B) {
|
||||
n, err := io.ReadFull(rand.Reader, pristine)
|
||||
require.EqualValues(tb, len(pristine), n)
|
||||
require.NoError(tb, err)
|
||||
tb.SetBytes(int64(len(pristine)))
|
||||
tb.ResetTimer()
|
||||
for range iter.N(tb.N) {
|
||||
func() {
|
||||
a, b := connPair()
|
||||
|
||||
Reference in New Issue
Block a user