diff --git a/diagnose.go b/diagnose.go new file mode 100644 index 0000000..ea6cba5 --- /dev/null +++ b/diagnose.go @@ -0,0 +1,192 @@ +package ngrok + +import ( + "cmp" + "context" + "crypto/tls" + "errors" + "fmt" + "log/slog" + "net" + "net/url" + "time" + + "golang.org/x/net/proxy" + + muxado "golang.ngrok.com/muxado/v2" + + "golang.ngrok.com/ngrok/v2/internal/legacy" + tunnelclient "golang.ngrok.com/ngrok/v2/internal/tunnel/client" + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" +) + +// DiagnoseResult holds the outcome of a successful diagnostic probe. +type DiagnoseResult struct { + // The address that was tested (ip:port or host:port). + Addr string + // Region reported by SrvInfo. + Region string + // Round-trip latency of the SrvInfo call. + Latency time.Duration +} + +// diagnoseError is returned by [Diagnoser.Diagnose] when a probe step fails. +// Use [IsTCPDiagnoseFailure], [IsTLSDiagnoseFailure], or +// [IsMuxadoDiagnoseFailure] to determine which step failed. +type diagnoseError struct { + // Step is the probe step that failed: "tcp", "tls", or "muxado". + Step string + // Err is the underlying error. + Err error +} + +func (e *diagnoseError) Error() string { + return fmt.Sprintf("diagnose %s: %v", e.Step, e.Err) +} + +func (e *diagnoseError) Unwrap() error { return e.Err } + +// IsTCPDiagnoseFailure reports whether err is a TCP-level probe failure. +func IsTCPDiagnoseFailure(err error) bool { + var de *diagnoseError + return errors.As(err, &de) && de.Step == "tcp" +} + +// IsTLSDiagnoseFailure reports whether err is a TLS-level probe failure. +func IsTLSDiagnoseFailure(err error) bool { + var de *diagnoseError + return errors.As(err, &de) && de.Step == "tls" +} + +// IsMuxadoDiagnoseFailure reports whether err is a muxado-level probe failure. +func IsMuxadoDiagnoseFailure(err error) bool { + var de *diagnoseError + return errors.As(err, &de) && de.Step == "muxado" +} + +// Diagnoser is implemented by Agent types that support pre-connection +// diagnostic probing. Use a type assertion to access it: +// +// d, ok := agent.(ngrok.Diagnoser) +type Diagnoser interface { + Agent + + // Diagnose tests connectivity to addr by probing TCP, TLS, and the Muxado + // tunnel protocol. It uses the Agent's configured TLS settings, CA roots, + // and proxy/dialer settings. + // + // If addr is empty, the configured server address is probed. + // + // This method does NOT establish a persistent session or call Auth. It is + // safe to call without affecting any existing connection. + Diagnose(ctx context.Context, addr string) (DiagnoseResult, error) +} + +// Diagnose implements Diagnoser. +func (a *agent) Diagnose(ctx context.Context, addr string) (DiagnoseResult, error) { + connectAddr := cmp.Or(a.opts.connectURL, "connect.ngrok-agent.com:443") + if addr == "" { + addr = connectAddr + } + + // Derive the TLS ServerName from the configured connect hostname, not from + // the addr under test (which may be an IP address that cannot be used for + // SNI). + serverName, _, err := net.SplitHostPort(connectAddr) + if err != nil { + // connectAddr has no port — use as-is. + serverName = connectAddr + } + + dialer, err := a.buildDiagnosticDialer() + if err != nil { + return DiagnoseResult{}, err + } + + logger := cmp.Or(a.opts.logger, slog.Default()) + + return a.probeAddr(ctx, logger, dialer, serverName, addr) +} + +// buildDiagnosticDialer returns the effective dialer for probes, applying +// proxy configuration without mutating agent state. +func (a *agent) buildDiagnosticDialer() (Dialer, error) { + baseDialer := cmp.Or(a.opts.dialer, Dialer(&net.Dialer{})) + if a.opts.proxyURL == "" { + return baseDialer, nil + } + parsedURL, err := url.Parse(a.opts.proxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + proxyDialer, err := proxy.FromURL(parsedURL, baseDialer) + if err != nil { + return nil, fmt.Errorf("failed to initialize proxy: %w", err) + } + dialer, ok := proxyDialer.(Dialer) + if !ok { + return nil, fmt.Errorf("proxy dialer is not compatible with ngrok Dialer interface") + } + return dialer, nil +} + +// probeAddr runs TCP → TLS → Muxado → SrvInfo for addr and returns a +// DiagnoseResult on success, or a *DiagnoseError indicating which step failed. +func (a *agent) probeAddr(ctx context.Context, logger *slog.Logger, dialer Dialer, serverName, addr string) (DiagnoseResult, error) { + result := DiagnoseResult{Addr: addr} + + // TCP + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return result, &diagnoseError{Step: "tcp", Err: err} + } + defer conn.Close() //nolint:errcheck + + // Interrupt I/O if the context is cancelled or expires. + stop := context.AfterFunc(ctx, func() { + conn.SetDeadline(time.Now()) //nolint:errcheck + }) + defer stop() + + // TLS + rootCAs := a.opts.connectCAs + if rootCAs == nil { + rootCAs = legacy.DefaultCAPool() + } + tlsCfg := &tls.Config{ + RootCAs: rootCAs, + ServerName: serverName, + MinVersion: tls.VersionTLS12, + } + if a.opts.tlsConfig != nil { + a.opts.tlsConfig(tlsCfg) + } + tlsConn := tls.Client(conn, tlsCfg) + if err := tlsConn.HandshakeContext(ctx); err != nil { + return result, &diagnoseError{Step: "tls", Err: err} + } + + // Muxado + SrvInfo + muxSess := muxado.Client(tlsConn, nil) + raw := tunnelclient.NewRawSession(logger, muxSess, nil, nopSessionHandler{}) + defer raw.Close() //nolint:errcheck + + start := time.Now() + info, err := raw.SrvInfo() + if err != nil { + return result, &diagnoseError{Step: "muxado", Err: err} + } + result.Region = info.Region + result.Latency = time.Since(start) + return result, nil +} + +// nopSessionHandler is a minimal SessionHandler that ignores all server RPCs. +// It is used by probeAddr, which never calls Accept() and therefore will never +// dispatch to these methods. +type nopSessionHandler struct{} + +func (nopSessionHandler) OnStop(*proto.Stop, tunnelclient.HandlerRespFunc) {} +func (nopSessionHandler) OnRestart(*proto.Restart, tunnelclient.HandlerRespFunc) {} +func (nopSessionHandler) OnUpdate(*proto.Update, tunnelclient.HandlerRespFunc) {} +func (nopSessionHandler) OnStopTunnel(*proto.StopTunnel, tunnelclient.HandlerRespFunc) {} diff --git a/diagnose_test.go b/diagnose_test.go new file mode 100644 index 0000000..443e659 --- /dev/null +++ b/diagnose_test.go @@ -0,0 +1,178 @@ +package ngrok + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "math/big" + "net" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + muxado "golang.ngrok.com/muxado/v2" + + "golang.ngrok.com/ngrok/v2/internal/tunnel/proto" +) + +// TestDiagnoseTCPFailure verifies that a connection refused at the TCP level +// is reported as a TCP step failure. +func TestDiagnoseTCPFailure(t *testing.T) { + // Bind and immediately close a listener so the port is unreachable. + l, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + addr := l.Addr().String() + _ = l.Close() + + a, err := NewAgent() + require.NoError(t, err) + + d, ok := a.(Diagnoser) + require.True(t, ok, "agent should implement Diagnoser") + + result, err := d.Diagnose(context.Background(), addr) + require.Error(t, err) + assert.True(t, IsTCPDiagnoseFailure(err)) + assert.Equal(t, addr, result.Addr) +} + +// TestDiagnoseTLSFailure verifies that a TCP-only server (no TLS) is reported +// as a TLS step failure. +func TestDiagnoseTLSFailure(t *testing.T) { + l, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + defer l.Close() //nolint:errcheck + + // Accept one connection and immediately close it. + go func() { + conn, err := l.Accept() + if err == nil { + _ = conn.Close() + } + }() + + a, err := NewAgent(WithAgentConnectURL(l.Addr().String())) + require.NoError(t, err) + + d := a.(Diagnoser) + + result, err := d.Diagnose(context.Background(), l.Addr().String()) + require.Error(t, err) + assert.True(t, IsTLSDiagnoseFailure(err)) + assert.Equal(t, l.Addr().String(), result.Addr) +} + +// TestDiagnoseMuxadoSuccess verifies the full happy path: TCP → TLS → Muxado +// → SrvInfo all succeed against a local test server. +func TestDiagnoseMuxadoSuccess(t *testing.T) { + // Generate a self-signed TLS certificate for the test server. + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "diagnose-test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + DNSNames: []string{"localhost"}, + } + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + require.NoError(t, err) + + tlsServerCfg := &tls.Config{ + Certificates: []tls.Certificate{{Certificate: [][]byte{certDER}, PrivateKey: priv}}, + } + + l, err := tls.Listen("tcp", "localhost:0", tlsServerCfg) + require.NoError(t, err) + defer l.Close() //nolint:errcheck + + const testRegion = "test-us" + + // Run a minimal Muxado server that responds to a single SrvInfo RPC. + muxadoDone := make(chan struct{}) + defer func() { <-muxadoDone }() + go func() { + defer close(muxadoDone) + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() //nolint:errcheck + + typed := muxado.NewTypedStreamSession(muxado.Server(conn, nil)) + for { + stream, err := typed.AcceptTypedStream() + if err != nil { + return + } + streamType := proto.ReqType(stream.StreamType()) + if streamType == proto.SrvInfoReq { + var req proto.SrvInfo + _ = json.NewDecoder(stream).Decode(&req) + assert.NoError(t, json.NewEncoder(stream).Encode(proto.SrvInfoResp{Region: testRegion})) + assert.NoError(t, stream.Close()) + return + } + // Drain any other stream types (e.g. heartbeat). + _ = stream.Close() + } + }() + + a, err := NewAgent( + WithAgentConnectURL(l.Addr().String()), + WithTLSConfig(func(c *tls.Config) { c.InsecureSkipVerify = true }), + ) + require.NoError(t, err) + + d := a.(Diagnoser) + + result, err := d.Diagnose(context.Background(), l.Addr().String()) + require.NoError(t, err) + assert.Equal(t, l.Addr().String(), result.Addr) + assert.Equal(t, testRegion, result.Region) + assert.Greater(t, result.Latency, time.Duration(0)) +} + +// TestDiagnoseOnline connects to a live tunnel server and verifies the full +// probe succeeds. Requires NGROK_TEST_ONLINE=1 or NGROK_TEST_ALL=1. +func TestDiagnoseOnline(t *testing.T) { + if os.Getenv("NGROK_TEST_ONLINE") == "" && os.Getenv("NGROK_TEST_ALL") == "" { + t.Skip("skipping online test; set NGROK_TEST_ONLINE=1 to run") + } + + serverAddr := os.Getenv("NGROK_CONNECT_URL") + if serverAddr == "" { + serverAddr = "connect.ngrok-agent.com:443" + } + + agentOpts := []AgentOption{WithAgentConnectURL(serverAddr)} + if os.Getenv("NGROK_TEST_INSECURE") != "" { + agentOpts = append(agentOpts, WithTLSConfig(func(c *tls.Config) { + c.InsecureSkipVerify = true + })) + } + + a, err := NewAgent(agentOpts...) + require.NoError(t, err) + + d, ok := a.(Diagnoser) + require.True(t, ok) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + result, err := d.Diagnose(ctx, serverAddr) + require.NoError(t, err) + t.Logf("addr=%s region=%s latency=%s", result.Addr, result.Region, result.Latency) + assert.NotEmpty(t, result.Region) + assert.Greater(t, result.Latency, time.Duration(0)) +} diff --git a/internal/legacy/session.go b/internal/legacy/session.go index 301ce32..5cf0363 100644 --- a/internal/legacy/session.go +++ b/internal/legacy/session.go @@ -50,6 +50,14 @@ type Session interface { //go:embed ngrok.ca.crt var defaultCACert []byte +// DefaultCAPool returns an [x509.CertPool] containing ngrok's default CA +// certificate. It is used as the fallback when no custom CA pool is provided. +func DefaultCAPool() *x509.CertPool { + pool := x509.NewCertPool() + pool.AppendCertsFromPEM(defaultCACert) + return pool +} + const defaultServer = "connect.ngrok-agent.com:443" var leastLatencyServer = regexp.MustCompile(`^connect\.([a-z]+?-)?ngrok-agent\.com(\.lan)?:443`) @@ -407,8 +415,7 @@ func Connect(ctx context.Context, opts ...ConnectOption) (Session, error) { } if cfg.CAPool == nil { - cfg.CAPool = x509.NewCertPool() - cfg.CAPool.AppendCertsFromPEM(defaultCACert) + cfg.CAPool = DefaultCAPool() } if cfg.ServerAddr == "" {