Introduce internal/testcontext package (#227)

This function will return a `Context` that obeys the `go test -timeout` flag.
Rather than having an arbitrary timeout that can cause flakiness,
the test runner can specify the timeout and the tests will run as long as they need.
This commit is contained in:
Roxy Light
2026-03-03 10:23:11 -08:00
committed by GitHub
parent 8ccd56abc2
commit f1eb6f970a
20 changed files with 82 additions and 60 deletions
+5 -6
View File
@@ -1,7 +1,6 @@
package ngrok
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@@ -20,6 +19,7 @@ import (
muxado "golang.ngrok.com/muxado/v2"
"golang.ngrok.com/ngrok/v2/internal/testcontext"
"golang.ngrok.com/ngrok/v2/internal/tunnel/proto"
)
@@ -38,7 +38,7 @@ func TestDiagnoseTCPFailure(t *testing.T) {
d, ok := a.(Diagnoser)
require.True(t, ok, "agent should implement Diagnoser")
result, err := d.Diagnose(context.Background(), addr)
result, err := d.Diagnose(testcontext.ForTB(t), addr)
require.Error(t, err)
assert.True(t, IsTCPDiagnoseFailure(err))
assert.Equal(t, addr, result.Addr)
@@ -64,7 +64,7 @@ func TestDiagnoseTLSFailure(t *testing.T) {
d := a.(Diagnoser)
result, err := d.Diagnose(context.Background(), l.Addr().String())
result, err := d.Diagnose(testcontext.ForTB(t), l.Addr().String())
require.Error(t, err)
assert.True(t, IsTLSDiagnoseFailure(err))
assert.Equal(t, l.Addr().String(), result.Addr)
@@ -135,7 +135,7 @@ func TestDiagnoseMuxadoSuccess(t *testing.T) {
d := a.(Diagnoser)
result, err := d.Diagnose(context.Background(), l.Addr().String())
result, err := d.Diagnose(testcontext.ForTB(t), l.Addr().String())
require.NoError(t, err)
assert.Equal(t, l.Addr().String(), result.Addr)
assert.Equal(t, testRegion, result.Region)
@@ -167,8 +167,7 @@ func TestDiagnoseOnline(t *testing.T) {
d, ok := a.(Diagnoser)
require.True(t, ok)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
ctx := testcontext.ForTB(t)
result, err := d.Diagnose(ctx, serverAddr)
require.NoError(t, err)
@@ -22,8 +22,7 @@ func TestAgentTLSTerminationIntegration(t *testing.T) {
cert := CreateTestCertificate(t)
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Setup synchronization primitives
@@ -18,8 +18,7 @@ import (
func TestEarlyResponseLargeUpload(t *testing.T) {
t.Parallel()
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
const maxBodySize = 1024
@@ -15,8 +15,7 @@ func TestEndpointClosingIntegration(t *testing.T) {
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Create a listener endpoint
@@ -13,8 +13,7 @@ func TestErrorCode(t *testing.T) {
SkipIfOffline(t)
t.Parallel()
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
// Create an endpoint with an invalid character ('@') in its URL
_, err := agent.Listen(ctx,
@@ -1,7 +1,6 @@
package integration_tests
import (
"context"
"os"
"testing"
"time"
@@ -9,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.ngrok.com/ngrok/v2"
"golang.ngrok.com/ngrok/v2/internal/testcontext"
)
// TestEventHandlingIntegration tests that events are properly emitted and received
@@ -60,8 +60,7 @@ func TestEventHandlingIntegration(t *testing.T) {
require.NoError(t, err, "Failed to create agent")
// Create a context with timeout for the test
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ctx := testcontext.ForTB(t)
// Connect the agent (should trigger a connect event)
t.Log("Connecting agent...")
+1 -2
View File
@@ -18,8 +18,7 @@ func TestForward(t *testing.T) {
// Mark this test for parallel execution
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Create a channel to signal when the server is ready
+2 -4
View File
@@ -24,8 +24,7 @@ func TestUpstreamProtocolHTTP2(t *testing.T) {
t.Parallel()
// Setup agent for this test
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Set up a test HTTP/2 server
@@ -92,8 +91,7 @@ func TestUpstreamProtocolHTTP2(t *testing.T) {
t.Parallel()
// Setup agent for this test
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Set up a test HTTP/2 server
@@ -15,8 +15,7 @@ func TestListenAndHTTPRequest(t *testing.T) {
// Mark this test for parallel execution
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Setup listener
@@ -16,8 +16,7 @@ func TestListenWithHTTPURL(t *testing.T) {
// Mark this test for parallel execution
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Setup listener with HTTP URL
@@ -16,8 +16,7 @@ func TestListenWithHTTPSURL(t *testing.T) {
// Mark this test for parallel execution
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Setup listener with HTTPS URL
@@ -16,8 +16,7 @@ func TestListenAndTCPConnection(t *testing.T) {
// Mark this test for parallel execution
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Setup TCP listener using the TCP scheme
@@ -221,10 +221,10 @@ func handleTCPConnection(t *testing.T, conn net.Conn, reader *bufio.Reader, srcA
}
// connectHTTPSClient connects to an HTTPS endpoint using an HTTP client
func connectHTTPSClient(t *testing.T, endpointURL string) {
func connectHTTPSClient(ctx context.Context, t *testing.T, endpointURL string) {
// Use MakeHTTPRequest to send test message
message := "Test message for PROXY protocol"
resp := MakeHTTPRequest(t, context.Background(), endpointURL, message)
resp := MakeHTTPRequest(t, ctx, endpointURL, message)
defer resp.Body.Close()
// Read the response
@@ -234,13 +234,13 @@ func connectHTTPSClient(t *testing.T, endpointURL string) {
}
// connectTCPClient connects to a TCP endpoint using a direct TCP connection
func connectTCPClient(t *testing.T, endpointURL string) {
func connectTCPClient(ctx context.Context, t *testing.T, endpointURL string) {
// For TCP, use direct TCP connection
u, err := url.Parse(endpointURL)
require.NoError(t, err, "Failed to parse URL")
// Connect to the endpoint using MakeTCPConnection
clientConn, err := MakeTCPConnection(t, context.Background(), u.Host)
clientConn, err := MakeTCPConnection(t, ctx, u.Host)
require.NoError(t, err, "Failed to connect to TCP endpoint")
defer clientConn.Close()
@@ -310,8 +310,7 @@ func TestProxyProtoIntegration(t *testing.T) {
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Create synchronization points
@@ -408,11 +407,11 @@ func TestProxyProtoIntegration(t *testing.T) {
// Connect to the endpoint with appropriate client based on scheme
switch {
case strings.HasPrefix(scheme, "https"):
connectHTTPSClient(t, endpointURL)
connectHTTPSClient(ctx, t, endpointURL)
case strings.HasPrefix(scheme, "tls"):
connectTLSClient(t, endpointURL)
default: // TCP
connectTCPClient(t, endpointURL)
connectTCPClient(ctx, t, endpointURL)
}
// Wait for the client address with timeout
+4 -4
View File
@@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/require"
"golang.ngrok.com/ngrok/v2"
"golang.ngrok.com/ngrok/v2/internal/testcontext"
)
// SkipIfOffline skips the test if NGROK_TEST_ONLINE environment variable is not set
@@ -30,7 +31,7 @@ func SkipIfOffline(t *testing.T) {
}
// SetupAgent creates and connects a new agent for testing
func SetupAgent(t *testing.T) (ngrok.Agent, context.Context, context.CancelFunc) {
func SetupAgent(t *testing.T) (ngrok.Agent, context.Context) {
// Skip if not running online tests
SkipIfOffline(t)
@@ -44,14 +45,13 @@ func SetupAgent(t *testing.T) (ngrok.Agent, context.Context, context.CancelFunc)
)
require.NoError(t, err, "Failed to create agent")
// Start a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx := testcontext.ForTB(t)
// Connect the agent
err = agent.Connect(ctx)
require.NoError(t, err, "Failed to connect agent")
return agent, ctx, cancel
return agent, ctx
}
// SetupListener sets up an ngrok listener with the specified options
@@ -52,8 +52,7 @@ func TestUpstreamDialer(t *testing.T) {
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Create a custom dialer that returns an error and has synchronization
@@ -23,8 +23,7 @@ func TestListenWithURLAndPooling(t *testing.T) {
t.Parallel()
// Setup agent
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer func() { _ = agent.Disconnect() }()
// Common URL for both endpoints - IMPORTANT: the exact same string must be used for both listeners
+1 -2
View File
@@ -18,8 +18,7 @@ import (
func TestWebSocketUpgrade(t *testing.T) {
t.Parallel()
agent, ctx, cancel := SetupAgent(t)
defer cancel()
agent, ctx := SetupAgent(t)
defer agent.Disconnect() //nolint:errcheck
// Start an upstream WebSocket server.
+11 -10
View File
@@ -17,6 +17,7 @@ import (
"golang.org/x/net/websocket"
"golang.ngrok.com/ngrok/v2/internal/legacy/config"
"golang.ngrok.com/ngrok/v2/internal/testcontext"
)
func newTestLogger(t *testing.T) *slog.Logger {
@@ -86,7 +87,7 @@ func serveHTTP(ctx context.Context, t *testing.T, connectOpts []ConnectOption, o
}
func TestTunnel(t *testing.T) {
ctx := context.Background()
ctx := testcontext.ForTB(t)
sess := setupSession(ctx, t)
tun := startTunnel(ctx, t, sess, config.HTTPEndpoint(
@@ -101,7 +102,7 @@ func TestTunnel(t *testing.T) {
}
func TestTunnelConnMetadata(t *testing.T) {
ctx := context.Background()
ctx := testcontext.ForTB(t)
sess := setupSession(ctx, t)
tun := startTunnel(ctx, t, sess, config.HTTPEndpoint())
@@ -141,7 +142,7 @@ func (f failPanic) FailNow() {
func TestTCP(t *testing.T) {
onlineTest(t)
ctx := context.Background()
ctx := testcontext.ForTB(t)
opts := config.TCPEndpoint()
@@ -168,7 +169,7 @@ func TestConnectionCallbacks(t *testing.T) {
// Don't run this one by default - it's timing-sensitive and prone to flakes
skipUnless(t, "NGROK_TEST_FLAKEY", "Skipping flakey network test")
ctx := context.Background()
ctx := testcontext.ForTB(t)
connects := 0
disconnectErrs := 0
disconnectNils := 0
@@ -201,7 +202,7 @@ type sketchyDialer struct {
}
func (sd *sketchyDialer) Dial(network, addr string) (net.Conn, error) {
return sd.DialContext(context.Background(), network, addr)
return sd.DialContext(context.TODO(), network, addr)
}
func (sd *sketchyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -218,7 +219,7 @@ func TestHeartbeatCallback(t *testing.T) {
t.Skip("Skipping long network test")
}
ctx := context.Background()
ctx := testcontext.ForTB(t)
heartbeats := 0
sess := setupSession(ctx, t,
WithHeartbeatHandler(func(ctx context.Context, sess Session, latency time.Duration) {
@@ -236,7 +237,7 @@ func TestHeartbeatCallback(t *testing.T) {
func TestPermanentErrors(t *testing.T) {
onlineTest(t)
var err error
ctx := context.Background()
ctx := testcontext.ForTB(t)
token := os.Getenv("NGROK_AUTHTOKEN")
sess, err := Connect(ctx, WithAuthtoken(token))
@@ -253,7 +254,7 @@ func TestRetryableErrors(t *testing.T) {
onlineTest(t)
var err error
// Set global context with a longer timeout just to prevent test from hanging forever
ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
ctx, cancel := context.WithTimeout(testcontext.ForTB(t), 8*time.Second)
defer cancel()
// Create a custom dialer with short timeout for invalid addresses
@@ -280,7 +281,7 @@ func TestRetryableErrors(t *testing.T) {
}
func TestNonExported(t *testing.T) {
ctx := context.Background()
ctx := testcontext.ForTB(t)
sess := setupSession(ctx, t)
@@ -294,7 +295,7 @@ func echo(ws *websocket.Conn) {
func TestWebsockets(t *testing.T) {
onlineTest(t)
ctx := context.Background()
ctx := testcontext.ForTB(t)
srv := &http.ServeMux{}
srv.Handle("/", helloHandler)
+37
View File
@@ -0,0 +1,37 @@
// Package testcontext provides a function to obtain a [context.Context] in a test.
package testcontext
import (
"context"
"testing"
"time"
)
// ForTB returns a [context.Context] that is canceled
// just before Cleanup-registered functions are called
// or shortly before the test deadline,
// whichever comes first.
func ForTB(tb testing.TB) context.Context {
ctx := tb.Context()
deadline, ok := tbDeadline(tb)
if !ok {
return ctx
}
ctx, cancel := context.WithDeadline(ctx, deadline.Add(-10*time.Second))
tb.Cleanup(cancel)
return ctx
}
func tbDeadline(tb testing.TB) (deadline time.Time, ok bool) {
d, ok := tb.(deadliner)
if !ok {
return time.Time{}, false
}
return d.Deadline()
}
type deadliner interface {
Deadline() (deadline time.Time, ok bool)
}
var _ deadliner = (*testing.T)(nil)
+2 -1
View File
@@ -8,6 +8,7 @@ import (
"time"
"golang.ngrok.com/muxado/v2"
"golang.ngrok.com/ngrok/v2/internal/testcontext"
)
type dummyStream struct{}
@@ -31,7 +32,7 @@ func TestHeartbeatTimeout(t *testing.T) {
}
func TestRawSessionCloseRace(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
ctx, cancel := context.WithTimeout(testcontext.ForTB(t), time.Second*1)
defer cancel()
// Since this is a race condition, run the test as many times as we can