diff --git a/diagnose_test.go b/diagnose_test.go index f178067..fb677a7 100644 --- a/diagnose_test.go +++ b/diagnose_test.go @@ -1,14 +1,8 @@ package ngrok import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" "encoding/json" - "math/big" "net" "os" "testing" @@ -20,6 +14,7 @@ import ( muxado "golang.ngrok.com/muxado/v2" "golang.ngrok.com/ngrok/v2/internal/testcontext" + "golang.ngrok.com/ngrok/v2/internal/tlstest" "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" ) @@ -73,22 +68,12 @@ func TestDiagnoseTLSFailure(t *testing.T) { // TestDiagnoseMuxadoSuccess verifies the full happy path: TCP → TLS → Muxado // → SrvInfo all succeed against a local test server. func TestDiagnoseMuxadoSuccess(t *testing.T) { - // Generate a self-signed TLS certificate for the test server. - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - template := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "diagnose-test"}, - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(24 * time.Hour), - DNSNames: []string{"localhost"}, + cert, err := tlstest.CreateCertificate() + if err != nil { + t.Fatal(err) } - certDER, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) - require.NoError(t, err) - tlsServerCfg := &tls.Config{ - Certificates: []tls.Certificate{{Certificate: [][]byte{certDER}, PrivateKey: priv}}, + Certificates: []tls.Certificate{*cert}, } l, err := tls.Listen("tcp", "localhost:0", tlsServerCfg) diff --git a/internal/integration_tests/agent_tls_termination_test.go b/internal/integration_tests/agent_tls_termination_test.go index c283e32..f46626b 100644 --- a/internal/integration_tests/agent_tls_termination_test.go +++ b/internal/integration_tests/agent_tls_termination_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "golang.ngrok.com/ngrok/v2" "golang.ngrok.com/ngrok/v2/internal/testutil" + "golang.ngrok.com/ngrok/v2/internal/tlstest" ) // TestAgentTLSTerminationIntegration tests agent-based TLS termination with custom certificates @@ -19,7 +20,10 @@ func TestAgentTLSTerminationIntegration(t *testing.T) { t.Parallel() // Generate test certificate - cert := CreateTestCertificate(t) + cert, err := tlstest.CreateCertificate() + if err != nil { + t.Fatal(err) + } // Setup agent agent, ctx := SetupAgent(t) diff --git a/internal/integration_tests/proxy_proto_test.go b/internal/integration_tests/proxy_proto_test.go index f76daba..0fb0185 100644 --- a/internal/integration_tests/proxy_proto_test.go +++ b/internal/integration_tests/proxy_proto_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/require" "golang.ngrok.com/ngrok/v2" "golang.ngrok.com/ngrok/v2/internal/testutil" + "golang.ngrok.com/ngrok/v2/internal/tlstest" ) // parseProxyProtocolHeader extracts client and server information from a PROXY protocol header. @@ -132,7 +133,10 @@ func verifyClientAddr(t *testing.T, clientAddr net.Addr) { // handleTLSConnection handles a TLS connection with PROXY protocol already read func handleTLSConnection(t *testing.T, conn net.Conn, reader *bufio.Reader, srcAddr net.Addr) { // Create a server TLS certificate for the handshake - servCert := CreateTestCertificate(t) + servCert, err := tlstest.CreateCertificate() + if err != nil { + t.Fatal(err) + } // Create TLS configuration for server config := &tls.Config{ diff --git a/internal/integration_tests/test_utils.go b/internal/integration_tests/test_utils.go index 42ffed2..f160f52 100644 --- a/internal/integration_tests/test_utils.go +++ b/internal/integration_tests/test_utils.go @@ -3,14 +3,9 @@ package integration_tests import ( "bufio" "context" - "crypto/rand" - "crypto/rsa" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" "fmt" "io" - "math/big" "net" "net/http" "os" @@ -118,36 +113,6 @@ func WaitForForwarderReady(t *testing.T, url string) { t.Logf("Forwarder endpoint didn't become ready in expected time, continuing anyway") } -// CreateTestCertificate creates a certificate for testing -func CreateTestCertificate(t *testing.T) *tls.Certificate { - // Generate a self-signed certificate for testing - privKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err, "Failed to generate private key") - - templ := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - CommonName: "localhost", - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, - } - - certDER, err := x509.CreateCertificate(rand.Reader, &templ, &templ, &privKey.PublicKey, privKey) - require.NoError(t, err, "Failed to create certificate") - - cert := tls.Certificate{ - Certificate: [][]byte{certDER}, - PrivateKey: privKey, - } - - return &cert -} - // MakeTCPConnection establishes a TCP connection to the given address func MakeTCPConnection(t *testing.T, ctx context.Context, address string) (io.ReadWriteCloser, error) { t.Helper() diff --git a/internal/tlstest/tlstest.go b/internal/tlstest/tlstest.go new file mode 100644 index 0000000..af70ecf --- /dev/null +++ b/internal/tlstest/tlstest.go @@ -0,0 +1,44 @@ +// Package tlstest provides a function for creating a TLS certificate suitable for testing. +package tlstest + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "time" +) + +// CreateCertificate creates a self-signed localhost TLS certificate with a 24-hour expiry. +func CreateCertificate() (*tls.Certificate, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + now := time.Now() + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "localhost"}, + NotBefore: now, + NotAfter: now.Add(24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, err + } + + return &tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: privateKey, + Leaf: template, + }, nil +} diff --git a/listener_test.go b/listener_test.go index 11fca60..e4a6396 100644 --- a/listener_test.go +++ b/listener_test.go @@ -1,56 +1,20 @@ package ngrok import ( - "crypto/rand" - "crypto/rsa" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "math/big" "net" "testing" - "time" + + "golang.ngrok.com/ngrok/v2/internal/tlstest" ) -// createTestCertificate creates a self-signed certificate for testing -func createTestCertificate(t *testing.T) *tls.Certificate { - // Generate a private key - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("Failed to generate private key: %v", err) - } - - // Create a certificate template - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "localhost"}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour * 24), // Valid for 24 hours - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - // Create a self-signed certificate - certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) - if err != nil { - t.Fatalf("Failed to create certificate: %v", err) - } - - // Create a TLS certificate - cert := &tls.Certificate{ - Certificate: [][]byte{certBytes}, - PrivateKey: privateKey, - Leaf: &template, - } - - return cert -} - // TestWrapConnWithTLS tests the TLS connection wrapper func TestWrapConnWithTLS(t *testing.T) { // Create a test certificate - cert := createTestCertificate(t) + cert, err := tlstest.CreateCertificate() + if err != nil { + t.Fatal(err) + } // Create a pipe for testing serverConn, clientConn := net.Pipe()