Merge pull request #225 from ngrok/bmps/sdk-diagnostics

feature: add diagnostics support
This commit is contained in:
Benjamin Pollack
2026-02-26 13:25:41 -05:00
committed by GitHub
3 changed files with 379 additions and 2 deletions
+192
View File
@@ -0,0 +1,192 @@
package ngrok
import (
"cmp"
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net"
"net/url"
"time"
"golang.org/x/net/proxy"
muxado "golang.ngrok.com/muxado/v2"
"golang.ngrok.com/ngrok/v2/internal/legacy"
tunnelclient "golang.ngrok.com/ngrok/v2/internal/tunnel/client"
"golang.ngrok.com/ngrok/v2/internal/tunnel/proto"
)
// DiagnoseResult describes the outcome of a diagnostic probe that completed
// successfully.
type DiagnoseResult struct {
	// Addr is the endpoint that was probed, given as ip:port or host:port.
	Addr string
	// Region is the region the server reported in its SrvInfo response.
	Region string
	// Latency is the round-trip time of the SrvInfo call.
	Latency time.Duration
}
// diagnoseError is the error type [Diagnoser.Diagnose] returns when one of the
// probe steps fails. Callers identify the failing step with
// [IsTCPDiagnoseFailure], [IsTLSDiagnoseFailure], or
// [IsMuxadoDiagnoseFailure].
type diagnoseError struct {
	// Step names the probe step that failed: "tcp", "tls", or "muxado".
	Step string
	// Err is the underlying cause of the failure.
	Err error
}

// Error renders the failure as "diagnose <step>: <cause>".
func (e *diagnoseError) Error() string {
	return "diagnose " + e.Step + ": " + fmt.Sprint(e.Err)
}

// Unwrap exposes the underlying cause so errors.Is and errors.As can walk
// through a diagnoseError.
func (e *diagnoseError) Unwrap() error {
	return e.Err
}

// diagnoseFailedAt reports whether the first *diagnoseError found in err's
// chain records the given step.
func diagnoseFailedAt(err error, step string) bool {
	var found *diagnoseError
	if !errors.As(err, &found) {
		return false
	}
	return found.Step == step
}

// IsTCPDiagnoseFailure reports whether err is a TCP-level probe failure.
func IsTCPDiagnoseFailure(err error) bool {
	return diagnoseFailedAt(err, "tcp")
}

// IsTLSDiagnoseFailure reports whether err is a TLS-level probe failure.
func IsTLSDiagnoseFailure(err error) bool {
	return diagnoseFailedAt(err, "tls")
}

// IsMuxadoDiagnoseFailure reports whether err is a muxado-level probe failure.
func IsMuxadoDiagnoseFailure(err error) bool {
	return diagnoseFailedAt(err, "muxado")
}
// Diagnoser is the optional capability of Agent implementations that can run
// pre-connection diagnostic probes. Obtain it with a type assertion:
//
//	d, ok := agent.(ngrok.Diagnoser)
type Diagnoser interface {
	Agent

	// Diagnose checks connectivity to addr by probing, in order, TCP, TLS,
	// and the Muxado tunnel protocol. The probe honours the Agent's
	// configured TLS settings, CA roots, and proxy/dialer settings.
	//
	// An empty addr means "probe the configured server address".
	//
	// Diagnose does NOT establish a persistent session and does NOT call
	// Auth, so it is safe to call without affecting any existing connection.
	Diagnose(ctx context.Context, addr string) (DiagnoseResult, error)
}
// Diagnose implements [Diagnoser].
//
// It probes addr, or the agent's configured connect address when addr is
// empty (falling back to "connect.ngrok-agent.com:443" when nothing is
// configured), and returns the result of the TCP → TLS → Muxado → SrvInfo
// sequence run by probeAddr.
//
// The returned DiagnoseResult always carries the probed address in Addr, even
// on failure. Probe failures are reported as errors recognised by
// [IsTCPDiagnoseFailure], [IsTLSDiagnoseFailure], or
// [IsMuxadoDiagnoseFailure]; an error building the dialer (for example an
// invalid proxy URL) is returned as-is and matches none of them.
func (a *agent) Diagnose(ctx context.Context, addr string) (DiagnoseResult, error) {
	connectAddr := cmp.Or(a.opts.connectURL, "connect.ngrok-agent.com:443")
	addr = cmp.Or(addr, connectAddr)

	// Derive the TLS ServerName from the configured connect hostname, not from
	// the addr under test (which may be an IP address that cannot be used for
	// SNI). If connectAddr has no port, it is used as-is.
	serverName := connectAddr
	if host, _, err := net.SplitHostPort(connectAddr); err == nil {
		serverName = host
	}

	dialer, err := a.buildDiagnosticDialer()
	if err != nil {
		// Keep Addr populated so this failure path reports the target the
		// same way the probe-step failures do.
		return DiagnoseResult{Addr: addr}, err
	}

	logger := cmp.Or(a.opts.logger, slog.Default())
	return a.probeAddr(ctx, logger, dialer, serverName, addr)
}
// buildDiagnosticDialer picks the dialer a probe should use. It starts from
// the agent's configured dialer (or a zero-value net.Dialer when none is set)
// and, if a proxy URL is configured, layers the proxy on top of it. Agent
// state is only read, never modified.
//
// It returns an error when the proxy URL cannot be parsed, the proxy cannot
// be initialized from it, or the resulting proxy dialer does not satisfy the
// Dialer interface.
func (a *agent) buildDiagnosticDialer() (Dialer, error) {
	var stdDialer Dialer = &net.Dialer{}
	direct := cmp.Or(a.opts.dialer, stdDialer)

	rawProxy := a.opts.proxyURL
	if rawProxy == "" {
		// No proxy configured: dial directly.
		return direct, nil
	}

	proxyAddr, err := url.Parse(rawProxy)
	if err != nil {
		return nil, fmt.Errorf("invalid proxy URL: %w", err)
	}

	viaProxy, err := proxy.FromURL(proxyAddr, direct)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize proxy: %w", err)
	}

	// proxy.FromURL hands back a proxy.Dialer; the probe needs the Dialer
	// interface used by this package.
	if d, ok := viaProxy.(Dialer); ok {
		return d, nil
	}
	return nil, fmt.Errorf("proxy dialer is not compatible with ngrok Dialer interface")
}
// probeAddr runs TCP → TLS → Muxado → SrvInfo against addr.
//
// On success it returns a DiagnoseResult with Addr, Region, and Latency set.
// On failure it returns a result carrying only Addr together with a
// *diagnoseError whose Step ("tcp", "tls", or "muxado") identifies the step
// that failed.
//
// serverName is the TLS ServerName to present; dialer opens the TCP
// connection; logger is handed to the raw tunnel session.
func (a *agent) probeAddr(ctx context.Context, logger *slog.Logger, dialer Dialer, serverName, addr string) (DiagnoseResult, error) {
	result := DiagnoseResult{Addr: addr}

	// TCP
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return result, &diagnoseError{Step: "tcp", Err: err}
	}
	defer conn.Close() //nolint:errcheck

	// Interrupt I/O if the context is cancelled or expires.
	stop := context.AfterFunc(ctx, func() {
		conn.SetDeadline(time.Now()) //nolint:errcheck
	})
	defer stop()

	// TLS: verify against the configured CA pool (or ngrok's default pool),
	// then let the user-supplied hook adjust the config.
	rootCAs := a.opts.connectCAs
	if rootCAs == nil {
		rootCAs = legacy.DefaultCAPool()
	}
	tlsCfg := &tls.Config{
		RootCAs:    rootCAs,
		ServerName: serverName,
		MinVersion: tls.VersionTLS12,
	}
	if a.opts.tlsConfig != nil {
		a.opts.tlsConfig(tlsCfg)
	}
	tlsConn := tls.Client(conn, tlsCfg)
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		return result, &diagnoseError{Step: "tls", Err: err}
	}

	// Muxado + SrvInfo
	muxSess := muxado.Client(tlsConn, nil)
	raw := tunnelclient.NewRawSession(logger, muxSess, nil, nopSessionHandler{})
	defer raw.Close() //nolint:errcheck

	start := time.Now()
	info, err := raw.SrvInfo()
	latency := time.Since(start)
	if err != nil {
		// SrvInfo takes no context; a cancellation or deadline reaches it
		// only through the connection deadline set above and would otherwise
		// surface as an opaque I/O error. Put the context error in the chain
		// so callers can detect it with errors.Is.
		if ctxErr := ctx.Err(); ctxErr != nil {
			err = fmt.Errorf("%w: %w", ctxErr, err)
		}
		return result, &diagnoseError{Step: "muxado", Err: err}
	}

	result.Region = info.Region
	result.Latency = latency
	return result, nil
}
// nopSessionHandler is a do-nothing SessionHandler: every server RPC callback
// is an empty method. probeAddr passes it to the raw session it creates;
// because probeAddr never calls Accept(), none of these methods are expected
// to be dispatched.
type nopSessionHandler struct{}

// OnStop ignores a Stop request.
func (nopSessionHandler) OnStop(_ *proto.Stop, _ tunnelclient.HandlerRespFunc) {}

// OnRestart ignores a Restart request.
func (nopSessionHandler) OnRestart(_ *proto.Restart, _ tunnelclient.HandlerRespFunc) {}

// OnUpdate ignores an Update request.
func (nopSessionHandler) OnUpdate(_ *proto.Update, _ tunnelclient.HandlerRespFunc) {}

// OnStopTunnel ignores a StopTunnel request.
func (nopSessionHandler) OnStopTunnel(_ *proto.StopTunnel, _ tunnelclient.HandlerRespFunc) {}
+178
View File
@@ -0,0 +1,178 @@
package ngrok
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"math/big"
"net"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
muxado "golang.ngrok.com/muxado/v2"
"golang.ngrok.com/ngrok/v2/internal/tunnel/proto"
)
// TestDiagnoseTCPFailure checks that a refused TCP connection is classified
// as a failure of the TCP step and that the result still names the probed
// address.
func TestDiagnoseTCPFailure(t *testing.T) {
	// Grab an ephemeral port and release it right away so that nothing is
	// listening on the address we probe.
	listener, err := net.Listen("tcp", "localhost:0")
	require.NoError(t, err)
	deadAddr := listener.Addr().String()
	_ = listener.Close()

	ag, err := NewAgent()
	require.NoError(t, err)

	diagnoser, implemented := ag.(Diagnoser)
	require.True(t, implemented, "agent should implement Diagnoser")

	res, diagErr := diagnoser.Diagnose(context.Background(), deadAddr)
	require.Error(t, diagErr)
	assert.True(t, IsTCPDiagnoseFailure(diagErr))
	assert.Equal(t, deadAddr, res.Addr)
}
// TestDiagnoseTLSFailure verifies that a TCP-only server (no TLS) is reported
// as a TLS step failure and that the result still names the probed address.
func TestDiagnoseTLSFailure(t *testing.T) {
	l, err := net.Listen("tcp", "localhost:0")
	require.NoError(t, err)
	defer l.Close() //nolint:errcheck
	addr := l.Addr().String()

	// Accept one connection and immediately close it, so the TCP dial
	// succeeds but the TLS handshake cannot complete.
	go func() {
		conn, err := l.Accept()
		if err == nil {
			_ = conn.Close()
		}
	}()

	a, err := NewAgent(WithAgentConnectURL(addr))
	require.NoError(t, err)

	// Use the checked assertion so a missing implementation fails the test
	// with a message rather than panicking.
	d, ok := a.(Diagnoser)
	require.True(t, ok, "agent should implement Diagnoser")

	result, err := d.Diagnose(context.Background(), addr)
	require.Error(t, err)
	assert.True(t, IsTLSDiagnoseFailure(err))
	assert.Equal(t, addr, result.Addr)
}
// TestDiagnoseMuxadoSuccess verifies the full happy path: TCP → TLS → Muxado
// → SrvInfo all succeed against a local test server, and the result reports
// the probed address, the server's region, and a positive latency.
func TestDiagnoseMuxadoSuccess(t *testing.T) {
	// Generate a self-signed TLS certificate for the test server.
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "diagnose-test"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(24 * time.Hour),
		DNSNames:     []string{"localhost"},
	}
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
	require.NoError(t, err)
	tlsServerCfg := &tls.Config{
		Certificates: []tls.Certificate{{Certificate: [][]byte{certDER}, PrivateKey: priv}},
	}

	l, err := tls.Listen("tcp", "localhost:0", tlsServerCfg)
	require.NoError(t, err)
	addr := l.Addr().String()

	const testRegion = "test-us"

	// Run a minimal Muxado server that responds to a single SrvInfo RPC.
	muxadoDone := make(chan struct{})
	defer func() {
		// Close the listener BEFORE waiting for the server goroutine. If the
		// test aborts before any client connects, the goroutine is still
		// blocked in Accept; waiting first would hang the test.
		_ = l.Close()
		<-muxadoDone
	}()
	go func() {
		defer close(muxadoDone)
		conn, err := l.Accept()
		if err != nil {
			return
		}
		defer conn.Close() //nolint:errcheck

		typed := muxado.NewTypedStreamSession(muxado.Server(conn, nil))
		for {
			stream, err := typed.AcceptTypedStream()
			if err != nil {
				return
			}
			streamType := proto.ReqType(stream.StreamType())
			if streamType == proto.SrvInfoReq {
				// The request payload is irrelevant to the test; a decode
				// error is ignored and the canned response is sent anyway.
				var req proto.SrvInfo
				_ = json.NewDecoder(stream).Decode(&req)
				assert.NoError(t, json.NewEncoder(stream).Encode(proto.SrvInfoResp{Region: testRegion}))
				assert.NoError(t, stream.Close())
				return
			}
			// Drain any other stream types (e.g. heartbeat).
			_ = stream.Close()
		}
	}()

	a, err := NewAgent(
		WithAgentConnectURL(addr),
		// The server certificate is self-signed, so skip verification.
		WithTLSConfig(func(c *tls.Config) { c.InsecureSkipVerify = true }),
	)
	require.NoError(t, err)

	d, ok := a.(Diagnoser)
	require.True(t, ok, "agent should implement Diagnoser")

	result, err := d.Diagnose(context.Background(), addr)
	require.NoError(t, err)
	assert.Equal(t, addr, result.Addr)
	assert.Equal(t, testRegion, result.Region)
	assert.Greater(t, result.Latency, time.Duration(0))
}
// TestDiagnoseOnline runs the full probe against a live tunnel server. It is
// skipped unless NGROK_TEST_ONLINE or NGROK_TEST_ALL is set to a non-empty
// value. NGROK_CONNECT_URL overrides the server address, and a non-empty
// NGROK_TEST_INSECURE disables TLS certificate verification.
func TestDiagnoseOnline(t *testing.T) {
	onlineRequested := os.Getenv("NGROK_TEST_ONLINE") != "" || os.Getenv("NGROK_TEST_ALL") != ""
	if !onlineRequested {
		t.Skip("skipping online test; set NGROK_TEST_ONLINE=1 to run")
	}

	target := os.Getenv("NGROK_CONNECT_URL")
	if target == "" {
		target = "connect.ngrok-agent.com:443"
	}

	options := []AgentOption{WithAgentConnectURL(target)}
	if insecure := os.Getenv("NGROK_TEST_INSECURE"); insecure != "" {
		skipVerify := func(c *tls.Config) {
			c.InsecureSkipVerify = true
		}
		options = append(options, WithTLSConfig(skipVerify))
	}

	ag, err := NewAgent(options...)
	require.NoError(t, err)
	diagnoser, implemented := ag.(Diagnoser)
	require.True(t, implemented)

	// Bound the whole probe so an unreachable server cannot stall the test.
	probeCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	res, err := diagnoser.Diagnose(probeCtx, target)
	require.NoError(t, err)
	t.Logf("addr=%s region=%s latency=%s", res.Addr, res.Region, res.Latency)
	assert.NotEmpty(t, res.Region)
	assert.Greater(t, res.Latency, time.Duration(0))
}
+9 -2
View File
@@ -50,6 +50,14 @@ type Session interface {
//go:embed ngrok.ca.crt
var defaultCACert []byte
// DefaultCAPool builds a new [x509.CertPool] on every call and loads ngrok's
// embedded default CA certificate into it. The pool serves as the fallback
// when no custom CA pool is provided.
//
// The success flag returned by AppendCertsFromPEM is not checked.
func DefaultCAPool() *x509.CertPool {
	certs := x509.NewCertPool()
	_ = certs.AppendCertsFromPEM(defaultCACert)
	return certs
}
// defaultServer is the address "connect.ngrok-agent.com:443".
// NOTE(review): presumably the fallback used when no server address is
// configured — confirm against Connect.
const defaultServer = "connect.ngrok-agent.com:443"
// leastLatencyServer matches addresses that start with "connect.", followed by
// an optional "<lowercase letters>-" prefix, then "ngrok-agent.com", an
// optional ".lan" suffix, and ":443". The pattern is anchored at the start
// only, so trailing characters after ":443" still match.
var leastLatencyServer = regexp.MustCompile(`^connect\.([a-z]+?-)?ngrok-agent\.com(\.lan)?:443`)
@@ -407,8 +415,7 @@ func Connect(ctx context.Context, opts ...ConnectOption) (Session, error) {
}
if cfg.CAPool == nil {
cfg.CAPool = x509.NewCertPool()
cfg.CAPool.AppendCertsFromPEM(defaultCACert)
cfg.CAPool = DefaultCAPool()
}
if cfg.ServerAddr == "" {