Files
utp/utp_test.go

570 lines
13 KiB
Go

package utp
import (
"bytes"
"crypto/rand"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"runtime"
"sync"
"testing"
"time"
_ "github.com/anacrolix/envpprof"
"github.com/bradfitz/iter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// init adds log.Lshortfile to the standard logger's existing flags.
func init() {
	flags := log.Flags()
	flags |= log.Lshortfile
	log.SetFlags(flags)
}
// setDefaultTestingDurations overrides the package-level writeTimeout,
// initialLatency and packetReadTimeout with the values used by the tests.
func setDefaultTestingDurations() {
	writeTimeout, initialLatency, packetReadTimeout =
		time.Second, 10*time.Millisecond, 2*time.Second
}
// TestUTPPingPong dials the listening socket from a helper goroutine, writes
// "ping" from the dialing side and expects "pong" to come back from the
// accepted side.
func TestUTPPingPong(t *testing.T) {
	defer goroutineLeakCheck(t)()
	s, err := NewSocket("udp", "localhost:0")
	require.NoError(t, err)
	defer s.Close()
	pingerClosed := make(chan struct{})
	go func() {
		defer close(pingerClosed)
		// Only assert (never require) in this goroutine: FailNow must be
		// called from the goroutine running the test function.
		b, err := Dial(s.Addr().String())
		if !assert.NoError(t, err) {
			return
		}
		defer b.Close()
		n, err := b.Write([]byte("ping"))
		if !assert.NoError(t, err) || !assert.EqualValues(t, 4, n) {
			return
		}
		buf := make([]byte, 4)
		// ReadFull reports short or failed reads instead of silently
		// comparing a partially filled buffer.
		if _, err := io.ReadFull(b, buf); !assert.NoError(t, err) {
			return
		}
		if assert.EqualValues(t, "pong", buf) {
			log.Printf("got pong")
		}
	}()
	a, err := s.Accept()
	require.NoError(t, err)
	defer a.Close()
	log.Printf("accepted %s", a)
	buf := make([]byte, 42)
	n, err := a.Read(buf)
	require.NoError(t, err)
	require.EqualValues(t, "ping", buf[:n])
	log.Print("got ping")
	n, err = a.Write([]byte("pong"))
	require.NoError(t, err)
	require.Equal(t, 4, n)
	log.Print("waiting for pinger to close")
	<-pingerClosed
}
// goroutineLeakCheck records the current goroutine count and returns a
// function that, when called, compares the count again and logs a message if
// it differs. When tests are not run verbosely the returned function does
// nothing.
func goroutineLeakCheck(t testing.TB) func() {
	noop := func() {}
	if !testing.Verbose() {
		return noop
	}
	baseline := runtime.NumGoroutine()
	return func() {
		// Number of times the count is sampled before giving up.
		const attempts = 1
		current := 0
		for i := 0; i < attempts; i++ {
			current = runtime.NumGoroutine()
			if current == baseline {
				return
			}
			time.Sleep(10 * time.Millisecond)
		}
		// I'd print stacks, or treat this as fatal, but I think
		// runtime.NumGoroutine is including system routines for which we are
		// not provided the stacks, and are spawned unpredictably.
		t.Logf("have %d goroutines, started with %d", current, baseline)
	}
}
// TestDialTimeout expects DialTimeout against a socket nobody is accepting on
// to return an error within the given (very short) timeout.
func TestDialTimeout(t *testing.T) {
	defer goroutineLeakCheck(t)()
	// Fail cleanly if the socket can't be created rather than dereferencing
	// a nil *Socket below.
	s, err := NewSocket("udp", "localhost:0")
	require.NoError(t, err)
	defer s.Close()
	conn, err := DialTimeout(s.Addr().String(), 10*time.Millisecond)
	if err == nil {
		conn.Close()
		t.Fatal("expected timeout")
	}
	t.Log(err)
}
func TestMinMaxHeaderType(t *testing.T) {
require.Equal(t, stSyn, stMax)
}
// TestConnReadDeadline checks that a Conn's read deadline makes Read return
// errTimeout once it has passed (and on subsequent reads), that clearing the
// deadline makes Read block again, and that closing the Conn unblocks both
// the local reader and the peer (which should see io.EOF).
func TestConnReadDeadline(t *testing.T) {
	t.Parallel()
	ls, err := NewSocket("udp", "localhost:0")
	require.NoError(t, err)
	defer ls.Close()
	ds, err := NewSocket("udp", "localhost:0")
	require.NoError(t, err)
	defer ds.Close()
	// Buffered so the dialer goroutine can always deliver its result and
	// exit, even if the test has already failed and nobody is receiving.
	dcReadErr := make(chan error, 1)
	go func() {
		c, err := ds.Dial(ls.Addr().String())
		if err != nil {
			// Report the dial failure instead of calling methods on a nil
			// connection.
			t.Errorf("dialing listener: %v", err)
			dcReadErr <- err
			return
		}
		defer c.Close()
		_, err = c.Read(nil)
		dcReadErr <- err
	}()
	c, err := ls.Accept()
	require.NoError(t, err)
	// Ensures the accepted Conn is released if an assertion below aborts the
	// test before the explicit Close.
	defer c.Close()
	dl := time.Now().Add(time.Millisecond)
	c.SetReadDeadline(dl)
	_, err = c.Read(nil)
	require.Equal(t, errTimeout, err)
	// The deadline has passed.
	if time.Now().Before(dl) {
		t.Fatal("deadline hasn't passed")
	}
	// Returns timeout on subsequent read.
	_, err = c.Read(nil)
	require.Equal(t, errTimeout, err)
	// Disable the deadline.
	c.SetReadDeadline(time.Time{})
	readReturned := make(chan struct{})
	go func() {
		c.Read(nil)
		close(readReturned)
	}()
	select {
	case <-readReturned:
		// Read returned but shouldn't have.
		t.Fatal("read returned")
	case <-time.After(time.Millisecond):
	}
	c.Close()
	if err := <-dcReadErr; err != io.EOF {
		// %v rather than %s: err may be nil here.
		t.Fatalf("dial conn read returned %v", err)
	}
	select {
	case <-readReturned:
	case <-time.After(time.Millisecond):
		t.Fatal("read should return after Conn is closed")
	}
}
// connectSelfLots opens n connections from a socket to itself, accepting them
// concurrently, then closes every connection, waits for the socket to have no
// registered conns left and closes the socket. It fails t on the first dial
// error.
func connectSelfLots(n int, t testing.TB) {
	defer goroutineLeakCheck(t)()
	s, err := NewSocket("udp", "localhost:0")
	require.NoError(t, err)
	go func() {
		for range iter.N(n) {
			c, err := s.Accept()
			if err != nil {
				log.Fatal(err)
			}
			// Deferred inside the loop: the accepted conns are only closed
			// when this goroutine returns, i.e. after all n accepts.
			defer c.Close()
		}
	}()
	// Both channels are buffered for n results so that every dial goroutine
	// can deliver its outcome and exit even after t.Fatal has stopped the
	// receive loop below; unbuffered channels would leave them blocked
	// forever.
	dialErr := make(chan error, n)
	connCh := make(chan net.Conn, n)
	// Limits the number of dials in flight to the accept backlog size.
	dialSema := make(chan struct{}, backlog)
	for range iter.N(n) {
		go func() {
			dialSema <- struct{}{}
			c, err := s.Dial(s.Addr().String())
			<-dialSema
			if err != nil {
				dialErr <- err
				return
			}
			connCh <- c
		}()
	}
	conns := make([]net.Conn, 0, n)
	for range iter.N(n) {
		select {
		case c := <-connCh:
			conns = append(conns, c)
		case err := <-dialErr:
			t.Fatal(err)
		}
	}
	for _, c := range conns {
		if c != nil {
			c.Close()
		}
	}
	sleepWhile(&mu, func() bool { return len(s.conns) != 0 })
	s.Close()
}
// TestConnectSelf makes a large number of connections from a socket to
// itself.
func TestConnectSelf(t *testing.T) {
	// A rough guess says that at worst, I can only have 0x10000/3 connections
	// to the same socket, due to fragmentation in the assigned connection
	// IDs.
	const selfConns = 0x100
	connectSelfLots(selfConns, t)
}
func BenchmarkConnectSelf(b *testing.B) {
for range iter.N(b.N) {
connectSelfLots(2, b)
}
}
func BenchmarkNewCloseSocket(b *testing.B) {
for range iter.N(b.N) {
s, err := NewSocket("udp", "localhost:0")
if err != nil {
b.Fatal(err)
}
err = s.Close()
if err != nil {
b.Fatal(err)
}
}
}
// TestRejectDialBacklogFilled fills a socket's accept backlog with pending
// dials, then checks that one further dial fails with "peer reset" and that
// the pending dials fail once the socket is closed.
func TestRejectDialBacklogFilled(t *testing.T) {
	s, err := NewSocket("udp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	errChan := make(chan error)
	// dial always reports its outcome on errChan. The result is checked by
	// the test goroutine: asserting with require here would call FailNow
	// from the wrong goroutine and, on failure, skip the send and leave the
	// test blocked on errChan.
	dial := func() {
		c, err := s.Dial(s.Addr().String())
		if err == nil {
			// Unexpected success; don't leak the connection.
			c.Close()
		}
		errChan <- err
	}
	// Fill the backlog.
	for range iter.N(backlog) {
		go dial()
	}
	sleepWhile(&mu, func() bool { return len(s.backlog) < backlog })
	select {
	case err := <-errChan:
		// %v rather than %s: a dial that unexpectedly succeeded sends nil.
		t.Fatalf("got premature dial result: %v", err)
	default:
	}
	// One more connection should cause a dial attempt to get reset.
	go dial()
	err = <-errChan
	assert.EqualError(t, err, "peer reset")
	s.Close()
	// Every dial that was parked in the backlog must now fail.
	for range iter.N(backlog) {
		assert.Error(t, <-errChan)
	}
}
// Make sure that we can reset AfterFunc timers, so we don't have to create
// brand new ones every time they fire. Specifically for the Conn resend timer.
func TestResetAfterFuncTimer(t *testing.T) {
	t.Parallel()
	const delay = time.Millisecond
	fired := make(chan struct{})
	notify := func() {
		fired <- struct{}{}
	}
	timer := time.AfterFunc(delay, notify)
	// Wait for the first expiry before re-arming.
	<-fired
	wasActive := timer.Reset(delay)
	if wasActive {
		// The timer should have expired
		t.FailNow()
	}
	// The re-armed timer must fire the same callback again.
	<-fired
}
// connPairSocket connects s to itself and returns both ends: initer is the
// dialing side and accepted is the side returned by Accept. It panics if
// either the dial or the accept fails.
func connPairSocket(s *Socket) (initer, accepted net.Conn) {
	// Closed once the dialing goroutine has stored its result in initer.
	dialed := make(chan struct{})
	go func() {
		defer close(dialed)
		var err error
		initer, err = s.Dial(s.Addr().String())
		if err != nil {
			panic(err)
		}
	}()
	accepted, err := s.Accept()
	if err != nil {
		panic(err)
	}
	<-dialed
	return initer, accepted
}
// connPair creates a fresh "inproc" Socket, connects it to itself via
// connPairSocket and returns the two ends. The Socket's Close is deferred, so
// it is called as connPair returns. It panics if the socket can't be created.
func connPair() (initer, accepted net.Conn) {
	sock, err := NewSocket("inproc", ":0")
	if err != nil {
		panic(err)
	}
	defer sock.Close()
	initer, accepted = connPairSocket(sock)
	return initer, accepted
}
// Check that peer sending FIN doesn't cause unread data to be dropped in a
// receiver.
func TestReadFinishedConn(t *testing.T) {
	a, b := connPair()
	defer a.Close()
	defer b.Close()
	// The writes happen with artificialPacketDropChance overridden. They run
	// in a closure so the original value is restored by a defer even when a
	// require below aborts the test; otherwise the override would leak into
	// every test that runs afterwards.
	func() {
		mu.Lock()
		originalAPDC := artificialPacketDropChance
		artificialPacketDropChance = 1
		mu.Unlock()
		defer func() {
			mu.Lock()
			artificialPacketDropChance = originalAPDC
			mu.Unlock()
		}()
		n, err := a.Write([]byte("hello"))
		require.Equal(t, 5, n)
		require.NoError(t, err)
		n, err = a.Write([]byte("world"))
		require.Equal(t, 5, n)
		require.NoError(t, err)
	}()
	// The drop chance is back to its original value before the writer closes
	// and the reader drains the connection.
	a.Close()
	all, err := ioutil.ReadAll(b)
	require.NoError(t, err)
	require.EqualValues(t, "helloworld", all)
}
// TestCloseDetachesQuickly connects a socket to itself, closes both ends and
// waits until the socket has no conns registered any more.
func TestCloseDetachesQuickly(t *testing.T) {
	t.Parallel()
	s, err := NewSocket("udp", "localhost:0")
	require.NoError(t, err)
	defer s.Close()
	go func() {
		a, err := s.Dial(s.Addr().String())
		if err != nil {
			// Report the failure rather than calling Close on a nil
			// connection. Error (unlike Fatal) may be used off the test
			// goroutine.
			t.Errorf("dialing self: %v", err)
			return
		}
		log.Print("close a")
		a.Close()
		log.Print("closed a")
	}()
	b, err := s.Accept()
	require.NoError(t, err)
	b.Close()
	sleepWhile(&mu, func() bool { return len(s.conns) != 0 })
}
// Check that closing, and resulting detach of a Conn doesn't close the parent
// Socket. We Accept, then close the connection and ensure it's detached. Then
// Accept again to check the Socket is still functional and unclosed.
func TestConnCloseUnclosedSocket(t *testing.T) {
	t.Parallel()
	s, err := NewSocket("udp", "localhost:0")
	require.NoError(t, err)
	defer func() {
		assert.NoError(t, s.Close())
	}()
	// Prevents the dialing goroutine from closing its end of the Conn before
	// we can check that it has been registered in the listener.
	dialerSync := make(chan struct{})
	go func() {
		// Only assert (never require) in this goroutine: FailNow must be
		// called from the goroutine running the test function.
		for range iter.N(2) {
			c, err := Dial(s.Addr().String())
			if !assert.NoError(t, err) {
				return
			}
			<-dialerSync
			if !assert.NoError(t, c.Close()) {
				return
			}
		}
	}()
	for range iter.N(2) {
		a, err := s.Accept()
		require.NoError(t, err)
		// We do this in a closure because we need to unlock mu if the test
		// failure exception is thrown. "Do as we say, not as we do" -Go
		// team.
		func() {
			mu.Lock()
			defer mu.Unlock()
			require.Len(t, s.conns, 1)
		}()
		dialerSync <- struct{}{}
		require.NoError(t, a.Close())
		sleepWhile(&mu, func() bool { return len(s.conns) != 0 })
	}
}
// TestPacketReadTimeout reads from one end of a connected pair while the
// other end stays silent and expects the read to fail with an error whose
// message contains "timeout".
func TestPacketReadTimeout(t *testing.T) {
	t.Parallel()
	a, b := connPair()
	// Close (and log the result for) both ends even when an assertion below
	// aborts the test.
	defer func() {
		t.Log(a.Close())
		t.Log(b.Close())
	}()
	_, err := a.Read(nil)
	// Guard the err.Error() call: a nil error must fail the test, not panic.
	require.Error(t, err)
	require.Contains(t, err.Error(), "timeout")
	t.Log(err)
}
// sleepWhile blocks until cond returns false. cond is evaluated with l held,
// and the caller sleeps for a millisecond between evaluations.
func sleepWhile(l sync.Locker, cond func() bool) {
	// A negative timeout means "no deadline", so this is the complete
	// implementation; no further polling loop is needed afterwards.
	sleepWhileTimeout(l, cond, -1)
}

// sleepWhileTimeout blocks until cond returns false or, when timeout is
// non-negative, until the timeout has elapsed. cond is evaluated with l held,
// and the caller sleeps for a millisecond between evaluations. A negative
// timeout disables the deadline. cond is always evaluated at least once.
func sleepWhileTimeout(l sync.Locker, cond func() bool, timeout time.Duration) {
	// The zero deadline is the "no deadline" marker.
	var deadline time.Time
	if timeout >= 0 {
		deadline = time.Now().Add(timeout)
	}
	for {
		l.Lock()
		val := cond()
		l.Unlock()
		if !val {
			return
		}
		if !deadline.IsZero() && time.Now().After(deadline) {
			return
		}
		time.Sleep(time.Millisecond)
	}
}
// TestMain registers an HTTP handler on "/" that writes the package status,
// applies the testing durations, runs the tests, and then waits up to a
// second for the package-level sockets collection to become empty. If any
// sockets remain, the status is written to stderr and the exit code is forced
// to 1.
func TestMain(m *testing.M) {
	statusHandler := func(w http.ResponseWriter, r *http.Request) {
		WriteStatus(w)
	}
	http.HandleFunc("/", statusHandler)
	setDefaultTestingDurations()
	code := m.Run()
	socketsRemain := func() bool {
		return len(sockets) != 0
	}
	sleepWhileTimeout(&mu, socketsRemain, time.Second)
	mu.Lock()
	leaked := len(sockets)
	mu.Unlock()
	if leaked > 0 {
		code = 1
		WriteStatus(os.Stderr)
	}
	os.Exit(code)
}
// TestAcceptReturnsAfterClose closes a socket concurrently with a call to
// Accept and logs whatever Accept returns; the test passes as long as Accept
// returns at all.
func TestAcceptReturnsAfterClose(t *testing.T) {
	sock, err := NewSocket("", "")
	require.NoError(t, err)
	go sock.Close()
	_, acceptErr := sock.Accept()
	t.Log(acceptErr)
}
// TestWriteClose writes a short message, closes the writer and expects the
// reader to receive exactly that message before the end of the stream.
func TestWriteClose(t *testing.T) {
	a, b := connPair()
	defer a.Close()
	defer b.Close()
	// Check the write itself so a failure is reported where it happens
	// instead of as a content mismatch after ReadAll.
	n, err := a.Write([]byte("hiho"))
	require.NoError(t, err)
	require.Equal(t, 4, n)
	a.Close()
	c, err := ioutil.ReadAll(b)
	require.NoError(t, err)
	require.EqualValues(t, "hiho", c)
	b.Close()
}
// Check that Conn.Write fails when the PacketConn that Socket wraps is
// closed.
func TestWriteUnderlyingPacketConnClosed(t *testing.T) {
	packetConn, err := listenPacket("inproc", "localhost:0")
	require.NoError(t, err)
	defer packetConn.Close()

	sock, err := NewSocketFromPacketConn(packetConn)
	require.NoError(t, err)
	defer sock.Close()

	dialed, acceptedConn := connPairSocket(sock)
	defer dialed.Close()
	defer acceptedConn.Close()

	// Pull the transport out from under the established connections.
	packetConn.Close()

	written, err := acceptedConn.Write([]byte("hello"))
	assert.Equal(t, 0, written)
	// It has to fail. I think it's a race between us writing to the real
	// PacketConn and getting "closed", and the Socket destroying itself, and
	// we get its destroy error.
	assert.Error(t, err)

	_, err = dialed.Read(nil)
	assert.EqualError(t, err, "Socket destroyed")
}
// TestFillBuffers writes random data to one end of a connected pair, without
// reading from the other end, until a write fails with errAckTimeout. It then
// closes the writer and checks that the reader receives exactly the bytes the
// writer reported as written.
func TestFillBuffers(t *testing.T) {
	a, b := connPair()
	// Release the writer even if an assertion aborts the test mid-loop.
	defer a.Close()
	defer b.Close()
	var sent []byte
	for {
		buf := make([]byte, 100000)
		// Without this check a failed read would go unnoticed and the test
		// would send a partially filled buffer.
		_, err := io.ReadFull(rand.Reader, buf)
		require.NoError(t, err)
		a.SetWriteDeadline(time.Now().Add(5 * time.Second))
		n, err := a.Write(buf)
		sent = append(sent, buf[:n]...)
		if err != nil {
			// Receiver will stop processing packets, packets will be dropped,
			// and thus not acked.
			assert.Equal(t, errAckTimeout, err)
			break
		}
		require.NotEqual(t, 0, n)
	}
	t.Logf("buffered %d bytes", len(sent))
	a.Close()
	all, err := ioutil.ReadAll(b)
	assert.NoError(t, err)
	assert.EqualValues(t, len(sent), len(all))
	assert.EqualValues(t, sent, all)
}
func TestConnLocalRemoteAddr(t *testing.T) {
a, b := connPair()
assert.EqualValues(t, "utp/inproc", a.LocalAddr().Network())
assert.EqualValues(t, "utp/inproc", a.RemoteAddr().Network())
assert.EqualValues(t, "utp/inproc", b.LocalAddr().Network())
assert.EqualValues(t, "utp/inproc", b.RemoteAddr().Network())
assert.EqualValues(t, a.LocalAddr().String(), b.RemoteAddr().String())
assert.EqualValues(t, b.LocalAddr().String(), a.RemoteAddr().String())
a.Close()
b.Close()
udpConn, err := net.ListenPacket("udp", "localhost:0")
require.NoError(t, err)
udpSock, err := NewSocketFromPacketConn(udpConn)
require.NoError(t, err)
a, b = connPairSocket(udpSock)
udpSock.Close()
assert.EqualValues(t, "utp/udp", a.LocalAddr().Network())
assert.EqualValues(t, "utp/udp", a.RemoteAddr().Network())
assert.EqualValues(t, "utp/udp", b.LocalAddr().Network())
assert.EqualValues(t, "utp/udp", b.RemoteAddr().Network())
assert.EqualValues(t, a.LocalAddr().String(), b.RemoteAddr().String())
assert.EqualValues(t, b.LocalAddr().String(), a.RemoteAddr().String())
a.Close()
b.Close()
}
// BenchmarkEchoLongBuffer measures sending a 20 MB random buffer over a
// connected pair whose far end echoes everything back, and verifies that the
// echoed bytes match what was sent.
func BenchmarkEchoLongBuffer(tb *testing.B) {
	pristine := make([]byte, 20000000)
	n, err := io.ReadFull(rand.Reader, pristine)
	require.EqualValues(tb, len(pristine), n)
	require.NoError(tb, err)
	tb.SetBytes(int64(len(pristine)))
	tb.ResetTimer()
	for range iter.N(tb.N) {
		func() {
			a, b := connPair()
			defer a.Close()
			defer b.Close()
			// The helper goroutines use assert, not require: FailNow must be
			// called from the goroutine running the benchmark function.
			go func() {
				// Echo: copy everything read from b straight back into b.
				n, err := io.Copy(b, b)
				assert.NoError(tb, err)
				assert.EqualValues(tb, len(pristine), n)
				b.Close()
			}()
			go func() {
				n, err := a.Write(pristine)
				assert.NoError(tb, err)
				assert.EqualValues(tb, len(pristine), n)
			}()
			echo := make([]byte, len(pristine))
			n, err := io.ReadFull(a, echo)
			a.Close()
			assert.NoError(tb, err)
			require.EqualValues(tb, len(echo), n)
			require.True(tb, bytes.Equal(pristine, echo))
		}()
	}
}