diff --git a/diagnose_test.go b/diagnose_test.go index 443e659..f178067 100644 --- a/diagnose_test.go +++ b/diagnose_test.go @@ -1,7 +1,6 @@ package ngrok import ( - "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -20,6 +19,7 @@ import ( muxado "golang.ngrok.com/muxado/v2" + "golang.ngrok.com/ngrok/v2/internal/testcontext" "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) @@ -38,7 +38,7 @@ func TestDiagnoseTCPFailure(t *testing.T) { d, ok := a.(Diagnoser) require.True(t, ok, "agent should implement Diagnoser") - result, err := d.Diagnose(context.Background(), addr) + result, err := d.Diagnose(testcontext.ForTB(t), addr) require.Error(t, err) assert.True(t, IsTCPDiagnoseFailure(err)) assert.Equal(t, addr, result.Addr) @@ -64,7 +64,7 @@ func TestDiagnoseTLSFailure(t *testing.T) { d := a.(Diagnoser) - result, err := d.Diagnose(context.Background(), l.Addr().String()) + result, err := d.Diagnose(testcontext.ForTB(t), l.Addr().String()) require.Error(t, err) assert.True(t, IsTLSDiagnoseFailure(err)) assert.Equal(t, l.Addr().String(), result.Addr) @@ -135,7 +135,7 @@ func TestDiagnoseMuxadoSuccess(t *testing.T) { d := a.(Diagnoser) - result, err := d.Diagnose(context.Background(), l.Addr().String()) + result, err := d.Diagnose(testcontext.ForTB(t), l.Addr().String()) require.NoError(t, err) assert.Equal(t, l.Addr().String(), result.Addr) assert.Equal(t, testRegion, result.Region) @@ -167,8 +167,7 @@ func TestDiagnoseOnline(t *testing.T) { d, ok := a.(Diagnoser) require.True(t, ok) - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() + ctx := testcontext.ForTB(t) result, err := d.Diagnose(ctx, serverAddr) require.NoError(t, err) diff --git a/internal/integration_tests/agent_tls_termination_test.go b/internal/integration_tests/agent_tls_termination_test.go index dcbe26e..b8fc23f 100644 --- a/internal/integration_tests/agent_tls_termination_test.go +++ b/internal/integration_tests/agent_tls_termination_test.go @@ -22,8 +22,7 @@ func TestAgentTLSTerminationIntegration(t *testing.T) { cert := CreateTestCertificate(t) // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Setup synchronization primitives diff --git a/internal/integration_tests/early_response_test.go b/internal/integration_tests/early_response_test.go index 48c1f06..d6b0cef 100644 --- a/internal/integration_tests/early_response_test.go +++ b/internal/integration_tests/early_response_test.go @@ -18,8 +18,7 @@ import ( func TestEarlyResponseLargeUpload(t *testing.T) { t.Parallel() - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() const maxBodySize = 1024 diff --git a/internal/integration_tests/endpoint_closing_test.go b/internal/integration_tests/endpoint_closing_test.go index e2b991d..b45f1fe 100644 --- a/internal/integration_tests/endpoint_closing_test.go +++ b/internal/integration_tests/endpoint_closing_test.go @@ -15,8 +15,7 @@ func TestEndpointClosingIntegration(t *testing.T) { t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Create a listener endpoint diff --git a/internal/integration_tests/error_code_test.go b/internal/integration_tests/error_code_test.go index 4d231a7..67e1b3d 100644 --- a/internal/integration_tests/error_code_test.go +++ b/internal/integration_tests/error_code_test.go @@ -13,8 +13,7 @@ func TestErrorCode(t *testing.T) { SkipIfOffline(t) t.Parallel() - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) // Create an endpoint with an invalid character ('@') in its URL _, err := agent.Listen(ctx, diff --git a/internal/integration_tests/event_handling_test.go b/internal/integration_tests/event_handling_test.go index 631f62f..9797c63 100644 --- a/internal/integration_tests/event_handling_test.go +++ b/internal/integration_tests/event_handling_test.go @@ -1,7 +1,6 @@ package integration_tests import ( - "context" "os" "testing" "time" @@ -9,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testcontext" ) // TestEventHandlingIntegration tests that events are properly emitted and received @@ -60,8 +60,7 @@ func TestEventHandlingIntegration(t *testing.T) { require.NoError(t, err, "Failed to create agent") // Create a context with timeout for the test - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + ctx := testcontext.ForTB(t) // Connect the agent (should trigger a connect event) t.Log("Connecting agent...") diff --git a/internal/integration_tests/forward_test.go b/internal/integration_tests/forward_test.go index 09c3a70..f384042 100644 --- a/internal/integration_tests/forward_test.go +++ b/internal/integration_tests/forward_test.go @@ -18,8 +18,7 @@ func TestForward(t *testing.T) { // Mark this test for parallel execution t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Create a channel to signal when the server is ready diff --git a/internal/integration_tests/http2_test.go b/internal/integration_tests/http2_test.go index b0b645e..903cfa1 100644 --- a/internal/integration_tests/http2_test.go +++ b/internal/integration_tests/http2_test.go @@ -24,8 +24,7 @@ func TestUpstreamProtocolHTTP2(t *testing.T) { t.Parallel() // Setup agent for this test - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Set up a test HTTP/2 server @@ -92,8 +91,7 @@ func TestUpstreamProtocolHTTP2(t *testing.T) { t.Parallel() // Setup agent for this test - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Set up a test HTTP/2 server diff --git a/internal/integration_tests/listen_http_test.go b/internal/integration_tests/listen_http_test.go index cdd1149..bba7f27 100644 --- a/internal/integration_tests/listen_http_test.go +++ b/internal/integration_tests/listen_http_test.go @@ -15,8 +15,7 @@ func TestListenAndHTTPRequest(t *testing.T) { // Mark this test for parallel execution t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Setup listener diff --git a/internal/integration_tests/listen_http_url_test.go b/internal/integration_tests/listen_http_url_test.go index 372329b..d0a0991 100644 --- a/internal/integration_tests/listen_http_url_test.go +++ b/internal/integration_tests/listen_http_url_test.go @@ -16,8 +16,7 @@ func TestListenWithHTTPURL(t *testing.T) { // Mark this test for parallel execution t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Setup listener with HTTP URL diff --git a/internal/integration_tests/listen_https_test.go b/internal/integration_tests/listen_https_test.go index 610713b..0fa7bb0 100644 --- a/internal/integration_tests/listen_https_test.go +++ b/internal/integration_tests/listen_https_test.go @@ -16,8 +16,7 @@ func TestListenWithHTTPSURL(t *testing.T) { // Mark this test for parallel execution t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Setup listener with HTTPS URL diff --git a/internal/integration_tests/listen_tcp_test.go b/internal/integration_tests/listen_tcp_test.go index 88e91b1..3fe9d08 100644 --- a/internal/integration_tests/listen_tcp_test.go +++ b/internal/integration_tests/listen_tcp_test.go @@ -16,8 +16,7 @@ func TestListenAndTCPConnection(t *testing.T) { // Mark this test for parallel execution t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Setup TCP listener using the TCP scheme diff --git a/internal/integration_tests/proxy_proto_test.go b/internal/integration_tests/proxy_proto_test.go index 0fdb40b..9be6b27 100644 --- a/internal/integration_tests/proxy_proto_test.go +++ b/internal/integration_tests/proxy_proto_test.go @@ -221,10 +221,10 @@ func handleTCPConnection(t *testing.T, conn net.Conn, reader *bufio.Reader, srcA } // connectHTTPSClient connects to an HTTPS endpoint using an HTTP client -func connectHTTPSClient(t *testing.T, endpointURL string) { +func connectHTTPSClient(ctx context.Context, t *testing.T, endpointURL string) { // Use MakeHTTPRequest to send test message message := "Test message for PROXY protocol" - resp := MakeHTTPRequest(t, context.Background(), endpointURL, message) + resp := MakeHTTPRequest(t, ctx, endpointURL, message) defer resp.Body.Close() // Read the response @@ -234,13 +234,13 @@ func connectHTTPSClient(t *testing.T, endpointURL string) { } // connectTCPClient connects to a TCP endpoint using a direct TCP connection -func connectTCPClient(t *testing.T, endpointURL string) { +func connectTCPClient(ctx context.Context, t *testing.T, endpointURL string) { // For TCP, use direct TCP connection u, err := url.Parse(endpointURL) require.NoError(t, err, "Failed to parse URL") // Connect to the endpoint using MakeTCPConnection - clientConn, err := MakeTCPConnection(t, context.Background(), u.Host) + clientConn, err := MakeTCPConnection(t, ctx, u.Host) require.NoError(t, err, "Failed to connect to TCP endpoint") defer clientConn.Close() @@ -310,8 +310,7 @@ func TestProxyProtoIntegration(t *testing.T) { t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Create synchronization points @@ -408,11 +407,11 @@ func TestProxyProtoIntegration(t *testing.T) { // Connect to the endpoint with appropriate client based on scheme switch { case strings.HasPrefix(scheme, "https"): - connectHTTPSClient(t, endpointURL) + connectHTTPSClient(ctx, t, endpointURL) case strings.HasPrefix(scheme, "tls"): connectTLSClient(t, endpointURL) default: // TCP - connectTCPClient(t, endpointURL) + connectTCPClient(ctx, t, endpointURL) } // Wait for the client address with timeout diff --git a/internal/integration_tests/test_utils.go b/internal/integration_tests/test_utils.go index 918385c..de43a6b 100644 --- a/internal/integration_tests/test_utils.go +++ b/internal/integration_tests/test_utils.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "golang.ngrok.com/ngrok/v2" + "golang.ngrok.com/ngrok/v2/internal/testcontext" ) // SkipIfOffline skips the test if NGROK_TEST_ONLINE environment variable is not set @@ -30,7 +31,7 @@ func SkipIfOffline(t *testing.T) { } // SetupAgent creates and connects a new agent for testing -func SetupAgent(t *testing.T) (ngrok.Agent, context.Context, context.CancelFunc) { +func SetupAgent(t *testing.T) (ngrok.Agent, context.Context) { // Skip if not running online tests SkipIfOffline(t) @@ -44,14 +45,13 @@ func SetupAgent(t *testing.T) (ngrok.Agent, context.Context, context.CancelFunc) ) require.NoError(t, err, "Failed to create agent") - // Start a context with timeout - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx := testcontext.ForTB(t) // Connect the agent err = agent.Connect(ctx) require.NoError(t, err, "Failed to connect agent") - return agent, ctx, cancel + return agent, ctx } // SetupListener sets up an ngrok listener with the specified options diff --git a/internal/integration_tests/upstream_dialer_test.go b/internal/integration_tests/upstream_dialer_test.go index 7537927..1d66e12 100644 --- a/internal/integration_tests/upstream_dialer_test.go +++ b/internal/integration_tests/upstream_dialer_test.go @@ -52,8 +52,7 @@ func TestUpstreamDialer(t *testing.T) { t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Create a custom dialer that returns an error and has synchronization diff --git a/internal/integration_tests/url_pooling_test.go b/internal/integration_tests/url_pooling_test.go index c8d85cb..a02df7e 100644 --- a/internal/integration_tests/url_pooling_test.go +++ b/internal/integration_tests/url_pooling_test.go @@ -23,8 +23,7 @@ func TestListenWithURLAndPooling(t *testing.T) { t.Parallel() // Setup agent - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer func() { _ = agent.Disconnect() }() // Common URL for both endpoints - IMPORTANT: the exact same string must be used for both listeners diff --git a/internal/integration_tests/websocket_test.go b/internal/integration_tests/websocket_test.go index 672aa00..145637c 100644 --- a/internal/integration_tests/websocket_test.go +++ b/internal/integration_tests/websocket_test.go @@ -18,8 +18,7 @@ import ( func TestWebSocketUpgrade(t *testing.T) { t.Parallel() - agent, ctx, cancel := SetupAgent(t) - defer cancel() + agent, ctx := SetupAgent(t) defer agent.Disconnect() //nolint:errcheck // Start an upstream WebSocket server. diff --git a/internal/legacy/online_test.go b/internal/legacy/online_test.go index b0ef51e..32ff1d9 100644 --- a/internal/legacy/online_test.go +++ b/internal/legacy/online_test.go @@ -17,6 +17,7 @@ import ( "golang.org/x/net/websocket" "golang.ngrok.com/ngrok/v2/internal/legacy/config" + "golang.ngrok.com/ngrok/v2/internal/testcontext" ) func newTestLogger(t *testing.T) *slog.Logger { @@ -86,7 +87,7 @@ func serveHTTP(ctx context.Context, t *testing.T, connectOpts []ConnectOption, o } func TestTunnel(t *testing.T) { - ctx := context.Background() + ctx := testcontext.ForTB(t) sess := setupSession(ctx, t) tun := startTunnel(ctx, t, sess, config.HTTPEndpoint( @@ -101,7 +102,7 @@ func TestTunnel(t *testing.T) { } func TestTunnelConnMetadata(t *testing.T) { - ctx := context.Background() + ctx := testcontext.ForTB(t) sess := setupSession(ctx, t) tun := startTunnel(ctx, t, sess, config.HTTPEndpoint()) @@ -141,7 +142,7 @@ func (f failPanic) FailNow() { func TestTCP(t *testing.T) { onlineTest(t) - ctx := context.Background() + ctx := testcontext.ForTB(t) opts := config.TCPEndpoint() @@ -168,7 +169,7 @@ func TestConnectionCallbacks(t *testing.T) { // Don't run this one by default - it's timing-sensitive and prone to flakes skipUnless(t, "NGROK_TEST_FLAKEY", "Skipping flakey network test") - ctx := context.Background() + ctx := testcontext.ForTB(t) connects := 0 disconnectErrs := 0 disconnectNils := 0 @@ -201,7 +202,7 @@ type sketchyDialer struct { } func (sd *sketchyDialer) Dial(network, addr string) (net.Conn, error) { - return sd.DialContext(context.Background(), network, addr) + return sd.DialContext(context.TODO(), network, addr) } func (sd *sketchyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { @@ -218,7 +219,7 @@ func TestHeartbeatCallback(t *testing.T) { t.Skip("Skipping long network test") } - ctx := context.Background() + ctx := testcontext.ForTB(t) heartbeats := 0 sess := setupSession(ctx, t, WithHeartbeatHandler(func(ctx context.Context, sess Session, latency time.Duration) { @@ -236,7 +237,7 @@ func TestHeartbeatCallback(t *testing.T) { func TestPermanentErrors(t *testing.T) { onlineTest(t) var err error - ctx := context.Background() + ctx := testcontext.ForTB(t) token := os.Getenv("NGROK_AUTHTOKEN") sess, err := Connect(ctx, WithAuthtoken(token)) @@ -253,7 +254,7 @@ func TestRetryableErrors(t *testing.T) { onlineTest(t) var err error // Set global context with a longer timeout just to prevent test from hanging forever - ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + ctx, cancel := context.WithTimeout(testcontext.ForTB(t), 8*time.Second) defer cancel() // Create a custom dialer with short timeout for invalid addresses @@ -280,7 +281,7 @@ func TestRetryableErrors(t *testing.T) { } func TestNonExported(t *testing.T) { - ctx := context.Background() + ctx := testcontext.ForTB(t) sess := setupSession(ctx, t) @@ -294,7 +295,7 @@ func echo(ws *websocket.Conn) { func TestWebsockets(t *testing.T) { onlineTest(t) - ctx := context.Background() + ctx := testcontext.ForTB(t) srv := &http.ServeMux{} srv.Handle("/", helloHandler) diff --git a/internal/testcontext/testcontext.go b/internal/testcontext/testcontext.go new file mode 100644 index 0000000..98b33fa --- /dev/null +++ b/internal/testcontext/testcontext.go @@ -0,0 +1,37 @@ +// Package testcontext provides a function to obtain a [context.Context] in a test. +package testcontext + +import ( + "context" + "testing" + "time" +) + +// ForTB returns a [context.Context] that is canceled +// just before Cleanup-registered functions are called +// or shortly before the test deadline, +// whichever comes first. +func ForTB(tb testing.TB) context.Context { + ctx := tb.Context() + deadline, ok := tbDeadline(tb) + if !ok { + return ctx + } + ctx, cancel := context.WithDeadline(ctx, deadline.Add(-10*time.Second)) + tb.Cleanup(cancel) + return ctx +} + +func tbDeadline(tb testing.TB) (deadline time.Time, ok bool) { + d, ok := tb.(deadliner) + if !ok { + return time.Time{}, false + } + return d.Deadline() +} + +type deadliner interface { + Deadline() (deadline time.Time, ok bool) +} + +var _ deadliner = (*testing.T)(nil) diff --git a/internal/tunnel/client/raw_session_test.go b/internal/tunnel/client/raw_session_test.go index 5f5926e..3275be8 100644 --- a/internal/tunnel/client/raw_session_test.go +++ b/internal/tunnel/client/raw_session_test.go @@ -8,6 +8,7 @@ import ( "time" "golang.ngrok.com/muxado/v2" + "golang.ngrok.com/ngrok/v2/internal/testcontext" ) type dummyStream struct{} @@ -31,7 +32,7 @@ func TestHeartbeatTimeout(t *testing.T) { } func TestRawSessionCloseRace(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) + ctx, cancel := context.WithTimeout(testcontext.ForTB(t), time.Second*1) defer cancel() // Since this is a race condition, run the test as many times as we can