mirror of
https://github.com/ngrok/ngrok-go.git
synced 2026-05-17 16:50:45 +00:00
Merge pull request #225 from ngrok/bmps/sdk-diagnostics
feature: add diagnostics support
This commit is contained in:
+192
@@ -0,0 +1,192 @@
|
||||
package ngrok
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
muxado "golang.ngrok.com/muxado/v2"
|
||||
|
||||
"golang.ngrok.com/ngrok/v2/internal/legacy"
|
||||
tunnelclient "golang.ngrok.com/ngrok/v2/internal/tunnel/client"
|
||||
"golang.ngrok.com/ngrok/v2/internal/tunnel/proto"
|
||||
)
|
||||
|
||||
// DiagnoseResult holds the outcome of a successful diagnostic probe.
|
||||
type DiagnoseResult struct {
|
||||
// The address that was tested (ip:port or host:port).
|
||||
Addr string
|
||||
// Region reported by SrvInfo.
|
||||
Region string
|
||||
// Round-trip latency of the SrvInfo call.
|
||||
Latency time.Duration
|
||||
}
|
||||
|
||||
// diagnoseError is returned by [Diagnoser.Diagnose] when a probe step fails.
|
||||
// Use [IsTCPDiagnoseFailure], [IsTLSDiagnoseFailure], or
|
||||
// [IsMuxadoDiagnoseFailure] to determine which step failed.
|
||||
type diagnoseError struct {
|
||||
// Step is the probe step that failed: "tcp", "tls", or "muxado".
|
||||
Step string
|
||||
// Err is the underlying error.
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *diagnoseError) Error() string {
|
||||
return fmt.Sprintf("diagnose %s: %v", e.Step, e.Err)
|
||||
}
|
||||
|
||||
func (e *diagnoseError) Unwrap() error { return e.Err }
|
||||
|
||||
// IsTCPDiagnoseFailure reports whether err is a TCP-level probe failure.
|
||||
func IsTCPDiagnoseFailure(err error) bool {
|
||||
var de *diagnoseError
|
||||
return errors.As(err, &de) && de.Step == "tcp"
|
||||
}
|
||||
|
||||
// IsTLSDiagnoseFailure reports whether err is a TLS-level probe failure.
|
||||
func IsTLSDiagnoseFailure(err error) bool {
|
||||
var de *diagnoseError
|
||||
return errors.As(err, &de) && de.Step == "tls"
|
||||
}
|
||||
|
||||
// IsMuxadoDiagnoseFailure reports whether err is a muxado-level probe failure.
|
||||
func IsMuxadoDiagnoseFailure(err error) bool {
|
||||
var de *diagnoseError
|
||||
return errors.As(err, &de) && de.Step == "muxado"
|
||||
}
|
||||
|
||||
// Diagnoser is implemented by Agent types that support pre-connection
|
||||
// diagnostic probing. Use a type assertion to access it:
|
||||
//
|
||||
// d, ok := agent.(ngrok.Diagnoser)
|
||||
type Diagnoser interface {
|
||||
Agent
|
||||
|
||||
// Diagnose tests connectivity to addr by probing TCP, TLS, and the Muxado
|
||||
// tunnel protocol. It uses the Agent's configured TLS settings, CA roots,
|
||||
// and proxy/dialer settings.
|
||||
//
|
||||
// If addr is empty, the configured server address is probed.
|
||||
//
|
||||
// This method does NOT establish a persistent session or call Auth. It is
|
||||
// safe to call without affecting any existing connection.
|
||||
Diagnose(ctx context.Context, addr string) (DiagnoseResult, error)
|
||||
}
|
||||
|
||||
// Diagnose implements Diagnoser.
|
||||
func (a *agent) Diagnose(ctx context.Context, addr string) (DiagnoseResult, error) {
|
||||
connectAddr := cmp.Or(a.opts.connectURL, "connect.ngrok-agent.com:443")
|
||||
if addr == "" {
|
||||
addr = connectAddr
|
||||
}
|
||||
|
||||
// Derive the TLS ServerName from the configured connect hostname, not from
|
||||
// the addr under test (which may be an IP address that cannot be used for
|
||||
// SNI).
|
||||
serverName, _, err := net.SplitHostPort(connectAddr)
|
||||
if err != nil {
|
||||
// connectAddr has no port — use as-is.
|
||||
serverName = connectAddr
|
||||
}
|
||||
|
||||
dialer, err := a.buildDiagnosticDialer()
|
||||
if err != nil {
|
||||
return DiagnoseResult{}, err
|
||||
}
|
||||
|
||||
logger := cmp.Or(a.opts.logger, slog.Default())
|
||||
|
||||
return a.probeAddr(ctx, logger, dialer, serverName, addr)
|
||||
}
|
||||
|
||||
// buildDiagnosticDialer returns the effective dialer for probes, applying
|
||||
// proxy configuration without mutating agent state.
|
||||
func (a *agent) buildDiagnosticDialer() (Dialer, error) {
|
||||
baseDialer := cmp.Or(a.opts.dialer, Dialer(&net.Dialer{}))
|
||||
if a.opts.proxyURL == "" {
|
||||
return baseDialer, nil
|
||||
}
|
||||
parsedURL, err := url.Parse(a.opts.proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
proxyDialer, err := proxy.FromURL(parsedURL, baseDialer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize proxy: %w", err)
|
||||
}
|
||||
dialer, ok := proxyDialer.(Dialer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("proxy dialer is not compatible with ngrok Dialer interface")
|
||||
}
|
||||
return dialer, nil
|
||||
}
|
||||
|
||||
// probeAddr runs TCP → TLS → Muxado → SrvInfo for addr and returns a
|
||||
// DiagnoseResult on success, or a *DiagnoseError indicating which step failed.
|
||||
func (a *agent) probeAddr(ctx context.Context, logger *slog.Logger, dialer Dialer, serverName, addr string) (DiagnoseResult, error) {
|
||||
result := DiagnoseResult{Addr: addr}
|
||||
|
||||
// TCP
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return result, &diagnoseError{Step: "tcp", Err: err}
|
||||
}
|
||||
defer conn.Close() //nolint:errcheck
|
||||
|
||||
// Interrupt I/O if the context is cancelled or expires.
|
||||
stop := context.AfterFunc(ctx, func() {
|
||||
conn.SetDeadline(time.Now()) //nolint:errcheck
|
||||
})
|
||||
defer stop()
|
||||
|
||||
// TLS
|
||||
rootCAs := a.opts.connectCAs
|
||||
if rootCAs == nil {
|
||||
rootCAs = legacy.DefaultCAPool()
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
ServerName: serverName,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
if a.opts.tlsConfig != nil {
|
||||
a.opts.tlsConfig(tlsCfg)
|
||||
}
|
||||
tlsConn := tls.Client(conn, tlsCfg)
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
return result, &diagnoseError{Step: "tls", Err: err}
|
||||
}
|
||||
|
||||
// Muxado + SrvInfo
|
||||
muxSess := muxado.Client(tlsConn, nil)
|
||||
raw := tunnelclient.NewRawSession(logger, muxSess, nil, nopSessionHandler{})
|
||||
defer raw.Close() //nolint:errcheck
|
||||
|
||||
start := time.Now()
|
||||
info, err := raw.SrvInfo()
|
||||
if err != nil {
|
||||
return result, &diagnoseError{Step: "muxado", Err: err}
|
||||
}
|
||||
result.Region = info.Region
|
||||
result.Latency = time.Since(start)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// nopSessionHandler is a minimal SessionHandler that ignores all server RPCs.
|
||||
// It is used by probeAddr, which never calls Accept() and therefore will never
|
||||
// dispatch to these methods.
|
||||
type nopSessionHandler struct{}
|
||||
|
||||
func (nopSessionHandler) OnStop(*proto.Stop, tunnelclient.HandlerRespFunc) {}
|
||||
func (nopSessionHandler) OnRestart(*proto.Restart, tunnelclient.HandlerRespFunc) {}
|
||||
func (nopSessionHandler) OnUpdate(*proto.Update, tunnelclient.HandlerRespFunc) {}
|
||||
func (nopSessionHandler) OnStopTunnel(*proto.StopTunnel, tunnelclient.HandlerRespFunc) {}
|
||||
@@ -0,0 +1,178 @@
|
||||
package ngrok
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
muxado "golang.ngrok.com/muxado/v2"
|
||||
|
||||
"golang.ngrok.com/ngrok/v2/internal/tunnel/proto"
|
||||
)
|
||||
|
||||
// TestDiagnoseTCPFailure verifies that a connection refused at the TCP level
|
||||
// is reported as a TCP step failure.
|
||||
func TestDiagnoseTCPFailure(t *testing.T) {
|
||||
// Bind and immediately close a listener so the port is unreachable.
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
addr := l.Addr().String()
|
||||
_ = l.Close()
|
||||
|
||||
a, err := NewAgent()
|
||||
require.NoError(t, err)
|
||||
|
||||
d, ok := a.(Diagnoser)
|
||||
require.True(t, ok, "agent should implement Diagnoser")
|
||||
|
||||
result, err := d.Diagnose(context.Background(), addr)
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsTCPDiagnoseFailure(err))
|
||||
assert.Equal(t, addr, result.Addr)
|
||||
}
|
||||
|
||||
// TestDiagnoseTLSFailure verifies that a TCP-only server (no TLS) is reported
|
||||
// as a TLS step failure.
|
||||
func TestDiagnoseTLSFailure(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
defer l.Close() //nolint:errcheck
|
||||
|
||||
// Accept one connection and immediately close it.
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if err == nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
a, err := NewAgent(WithAgentConnectURL(l.Addr().String()))
|
||||
require.NoError(t, err)
|
||||
|
||||
d := a.(Diagnoser)
|
||||
|
||||
result, err := d.Diagnose(context.Background(), l.Addr().String())
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsTLSDiagnoseFailure(err))
|
||||
assert.Equal(t, l.Addr().String(), result.Addr)
|
||||
}
|
||||
|
||||
// TestDiagnoseMuxadoSuccess verifies the full happy path: TCP → TLS → Muxado
|
||||
// → SrvInfo all succeed against a local test server.
|
||||
func TestDiagnoseMuxadoSuccess(t *testing.T) {
|
||||
// Generate a self-signed TLS certificate for the test server.
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "diagnose-test"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsServerCfg := &tls.Config{
|
||||
Certificates: []tls.Certificate{{Certificate: [][]byte{certDER}, PrivateKey: priv}},
|
||||
}
|
||||
|
||||
l, err := tls.Listen("tcp", "localhost:0", tlsServerCfg)
|
||||
require.NoError(t, err)
|
||||
defer l.Close() //nolint:errcheck
|
||||
|
||||
const testRegion = "test-us"
|
||||
|
||||
// Run a minimal Muxado server that responds to a single SrvInfo RPC.
|
||||
muxadoDone := make(chan struct{})
|
||||
defer func() { <-muxadoDone }()
|
||||
go func() {
|
||||
defer close(muxadoDone)
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close() //nolint:errcheck
|
||||
|
||||
typed := muxado.NewTypedStreamSession(muxado.Server(conn, nil))
|
||||
for {
|
||||
stream, err := typed.AcceptTypedStream()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
streamType := proto.ReqType(stream.StreamType())
|
||||
if streamType == proto.SrvInfoReq {
|
||||
var req proto.SrvInfo
|
||||
_ = json.NewDecoder(stream).Decode(&req)
|
||||
assert.NoError(t, json.NewEncoder(stream).Encode(proto.SrvInfoResp{Region: testRegion}))
|
||||
assert.NoError(t, stream.Close())
|
||||
return
|
||||
}
|
||||
// Drain any other stream types (e.g. heartbeat).
|
||||
_ = stream.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
a, err := NewAgent(
|
||||
WithAgentConnectURL(l.Addr().String()),
|
||||
WithTLSConfig(func(c *tls.Config) { c.InsecureSkipVerify = true }),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
d := a.(Diagnoser)
|
||||
|
||||
result, err := d.Diagnose(context.Background(), l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, l.Addr().String(), result.Addr)
|
||||
assert.Equal(t, testRegion, result.Region)
|
||||
assert.Greater(t, result.Latency, time.Duration(0))
|
||||
}
|
||||
|
||||
// TestDiagnoseOnline connects to a live tunnel server and verifies the full
|
||||
// probe succeeds. Requires NGROK_TEST_ONLINE=1 or NGROK_TEST_ALL=1.
|
||||
func TestDiagnoseOnline(t *testing.T) {
|
||||
if os.Getenv("NGROK_TEST_ONLINE") == "" && os.Getenv("NGROK_TEST_ALL") == "" {
|
||||
t.Skip("skipping online test; set NGROK_TEST_ONLINE=1 to run")
|
||||
}
|
||||
|
||||
serverAddr := os.Getenv("NGROK_CONNECT_URL")
|
||||
if serverAddr == "" {
|
||||
serverAddr = "connect.ngrok-agent.com:443"
|
||||
}
|
||||
|
||||
agentOpts := []AgentOption{WithAgentConnectURL(serverAddr)}
|
||||
if os.Getenv("NGROK_TEST_INSECURE") != "" {
|
||||
agentOpts = append(agentOpts, WithTLSConfig(func(c *tls.Config) {
|
||||
c.InsecureSkipVerify = true
|
||||
}))
|
||||
}
|
||||
|
||||
a, err := NewAgent(agentOpts...)
|
||||
require.NoError(t, err)
|
||||
|
||||
d, ok := a.(Diagnoser)
|
||||
require.True(t, ok)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := d.Diagnose(ctx, serverAddr)
|
||||
require.NoError(t, err)
|
||||
t.Logf("addr=%s region=%s latency=%s", result.Addr, result.Region, result.Latency)
|
||||
assert.NotEmpty(t, result.Region)
|
||||
assert.Greater(t, result.Latency, time.Duration(0))
|
||||
}
|
||||
@@ -50,6 +50,14 @@ type Session interface {
|
||||
//go:embed ngrok.ca.crt
|
||||
var defaultCACert []byte
|
||||
|
||||
// DefaultCAPool returns an [x509.CertPool] containing ngrok's default CA
|
||||
// certificate. It is used as the fallback when no custom CA pool is provided.
|
||||
func DefaultCAPool() *x509.CertPool {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AppendCertsFromPEM(defaultCACert)
|
||||
return pool
|
||||
}
|
||||
|
||||
const defaultServer = "connect.ngrok-agent.com:443"
|
||||
|
||||
var leastLatencyServer = regexp.MustCompile(`^connect\.([a-z]+?-)?ngrok-agent\.com(\.lan)?:443`)
|
||||
@@ -407,8 +415,7 @@ func Connect(ctx context.Context, opts ...ConnectOption) (Session, error) {
|
||||
}
|
||||
|
||||
if cfg.CAPool == nil {
|
||||
cfg.CAPool = x509.NewCertPool()
|
||||
cfg.CAPool.AppendCertsFromPEM(defaultCACert)
|
||||
cfg.CAPool = DefaultCAPool()
|
||||
}
|
||||
|
||||
if cfg.ServerAddr == "" {
|
||||
|
||||
Reference in New Issue
Block a user