TUN-9469: Centralize UDP origin proxy dialing as ingress service

Introduces a new `UDPOriginProxy` interface and `UDPOriginService`
to standardize how UDP connections are dialed to origins. Allows for
future overrides of the ingress service for specific dial destinations.

Simplifies dependency injection for UDP dialing throughout both datagram
v2 and v3 by using the same ingress service. Previous invocations called
into a DialUDP function in the ingress package that was a light
wrapper over `net.DialUDP`. Now a reference is passed into both datagram
controllers that allows more control over the DialUDP method.

Closes TUN-9469
This commit is contained in:
Devin Carr
2025-06-23 18:01:15 +00:00
parent 64fdc52855
commit b4a98b13fe
9 changed files with 93 additions and 40 deletions

View File

@@ -830,6 +830,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
sessionManager, sessionManager,
cfdflow.NewLimiter(0), cfdflow.NewLimiter(0),
datagramMuxer, datagramMuxer,
ingress.DefaultUDPDialer,
packetRouter, packetRouter,
15 * time.Second, 15 * time.Second,
0 * time.Second, 0 * time.Second,

View File

@@ -4,9 +4,11 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors"
pkgerrors "github.com/pkg/errors" pkgerrors "github.com/pkg/errors"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -32,6 +34,10 @@ const (
demuxChanCapacity = 16 demuxChanCapacity = 16
) )
var (
errInvalidDestinationIP = errors.New("unable to parse destination IP")
)
// DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming // DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming
// connection streams. // connection streams.
type DatagramSessionHandler interface { type DatagramSessionHandler interface {
@@ -51,7 +57,10 @@ type datagramV2Connection struct {
// datagramMuxer mux/demux datagrams from quic connection // datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *cfdquic.DatagramMuxerV2 datagramMuxer *cfdquic.DatagramMuxerV2
packetRouter *ingress.PacketRouter // ingressUDPProxy acts as the origin dialer for UDP requests
ingressUDPProxy ingress.UDPOriginProxy
// packetRouter acts as the origin router for ICMP requests
packetRouter *ingress.PacketRouter
rpcTimeout time.Duration rpcTimeout time.Duration
streamWriteTimeout time.Duration streamWriteTimeout time.Duration
@@ -61,6 +70,7 @@ type datagramV2Connection struct {
func NewDatagramV2Connection(ctx context.Context, func NewDatagramV2Connection(ctx context.Context,
conn quic.Connection, conn quic.Connection,
ingressUDPProxy ingress.UDPOriginProxy,
icmpRouter ingress.ICMPRouter, icmpRouter ingress.ICMPRouter,
index uint8, index uint8,
rpcTimeout time.Duration, rpcTimeout time.Duration,
@@ -79,6 +89,7 @@ func NewDatagramV2Connection(ctx context.Context,
sessionManager: sessionManager, sessionManager: sessionManager,
flowLimiter: flowLimiter, flowLimiter: flowLimiter,
datagramMuxer: datagramMuxer, datagramMuxer: datagramMuxer,
ingressUDPProxy: ingressUDPProxy,
packetRouter: packetRouter, packetRouter: packetRouter,
rpcTimeout: rpcTimeout, rpcTimeout: rpcTimeout,
streamWriteTimeout: streamWriteTimeout, streamWriteTimeout: streamWriteTimeout,
@@ -128,12 +139,29 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
tracing.EndWithErrorStatus(registerSpan, err) tracing.EndWithErrorStatus(registerSpan, err)
return nil, err return nil, err
} }
// We need to force the net.IP to IPv4 (if it's an IPv4 address) otherwise the net.IP conversion from capnp
// will be a IPv4-mapped-IPv6 address.
// In the case that the address is IPv6 we leave it untouched and parse it as normal.
ip := dstIP.To4()
if ip == nil {
ip = dstIP
}
// Parse the dstIP and dstPort into a netip.AddrPort
// This should never fail because the IP was already parsed as a valid net.IP
destAddr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Err(errInvalidDestinationIP).Msgf("Failed to parse destination proxy IP: %s", ip)
tracing.EndWithErrorStatus(registerSpan, errInvalidDestinationIP)
q.flowLimiter.Release()
return nil, errInvalidDestinationIP
}
dstAddrPort := netip.AddrPortFrom(destAddr, dstPort)
// Each session is a series of datagram from an eyeball to a dstIP:dstPort. // Each session is a series of datagram from an eyeball to a dstIP:dstPort.
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
originProxy, err := ingress.DialUDP(dstIP, dstPort) originProxy, err := q.ingressUDPProxy.DialUDP(dstAddrPort)
if err != nil { if err != nil {
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) log.Err(err).Msgf("Failed to create udp proxy to %s", dstAddrPort)
tracing.EndWithErrorStatus(registerSpan, err) tracing.EndWithErrorStatus(registerSpan, err)
q.flowLimiter.Release() q.flowLimiter.Release()
return nil, err return nil, err

View File

@@ -13,6 +13,7 @@ import (
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
cfdflow "github.com/cloudflare/cloudflared/flow" cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/mocks" "github.com/cloudflare/cloudflared/mocks"
) )
@@ -83,6 +84,7 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
datagramConn := NewDatagramV2Connection( datagramConn := NewDatagramV2Connection(
t.Context(), t.Context(),
conn, conn,
ingress.DefaultUDPDialer,
nil, nil,
0, 0,
0*time.Second, 0*time.Second,

View File

@@ -2,37 +2,57 @@ package ingress
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"net/netip" "net/netip"
"github.com/rs/zerolog"
) )
type UDPProxy interface { // UDPOriginService provides a proxy UDP dialer to origin services while allowing reserved
io.ReadWriteCloser // services to be provided. These reserved services are assigned to specific [netip.AddrPort]s
LocalAddr() net.Addr // and provide their own [UDPOriginProxy]s to handle UDP origin dialing.
type UDPOriginService struct {
// Reserved services for reserved AddrPort values
reservedServices map[netip.AddrPort]UDPOriginProxy
// The default UDP Dialer used if no reserved services are found for an origin request.
defaultDialer UDPOriginProxy
logger *zerolog.Logger
} }
type udpProxy struct { // UDPOriginProxy provides a UDP dial operation to a requested addr.
*net.UDPConn type UDPOriginProxy interface {
DialUDP(addr netip.AddrPort) (*net.UDPConn, error)
} }
func DialUDP(dstIP net.IP, dstPort uint16) (UDPProxy, error) { func NewUDPOriginService(reserved map[netip.AddrPort]UDPOriginProxy, logger *zerolog.Logger) *UDPOriginService {
dstAddr := &net.UDPAddr{ return &UDPOriginService{
IP: dstIP, reservedServices: reserved,
Port: int(dstPort), defaultDialer: DefaultUDPDialer,
logger: logger,
} }
// We use nil as local addr to force runtime to find the best suitable local address IP given the destination
// address as context.
udpConn, err := net.DialUDP("udp", nil, dstAddr)
if err != nil {
return nil, fmt.Errorf("unable to create UDP proxy to origin (%v:%v): %w", dstIP, dstPort, err)
}
return &udpProxy{udpConn}, nil
} }
func DialUDPAddrPort(dest netip.AddrPort) (*net.UDPConn, error) { // SetUDPDialer updates the default UDP Dialer used.
// Typically used in unit testing.
func (s *UDPOriginService) SetDefaultDialer(dialer UDPOriginProxy) {
s.defaultDialer = dialer
}
// DialUDP will perform a dial UDP to the requested addr.
func (s *UDPOriginService) DialUDP(addr netip.AddrPort) (*net.UDPConn, error) {
// Check to see if any reserved services are available for this addr and call their dialer instead.
if dialer, ok := s.reservedServices[addr]; ok {
return dialer.DialUDP(addr)
}
return s.defaultDialer.DialUDP(addr)
}
type defaultUDPDialer struct{}
var DefaultUDPDialer UDPOriginProxy = &defaultUDPDialer{}
func (d *defaultUDPDialer) DialUDP(dest netip.AddrPort) (*net.UDPConn, error) {
addr := net.UDPAddrFromAddrPort(dest) addr := net.UDPAddrFromAddrPort(dest)
// We use nil as local addr to force runtime to find the best suitable local address IP given the destination // We use nil as local addr to force runtime to find the best suitable local address IP given the destination

View File

@@ -2,12 +2,11 @@ package v3
import ( import (
"errors" "errors"
"net"
"net/netip"
"sync" "sync"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
cfdflow "github.com/cloudflare/cloudflared/flow" cfdflow "github.com/cloudflare/cloudflared/flow"
@@ -38,18 +37,16 @@ type SessionManager interface {
UnregisterSession(requestID RequestID) UnregisterSession(requestID RequestID)
} }
type DialUDP func(dest netip.AddrPort) (*net.UDPConn, error)
type sessionManager struct { type sessionManager struct {
sessions map[RequestID]Session sessions map[RequestID]Session
mutex sync.RWMutex mutex sync.RWMutex
originDialer DialUDP originDialer ingress.UDPOriginProxy
limiter cfdflow.Limiter limiter cfdflow.Limiter
metrics Metrics metrics Metrics
log *zerolog.Logger log *zerolog.Logger
} }
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdflow.Limiter) SessionManager { func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer ingress.UDPOriginProxy, limiter cfdflow.Limiter) SessionManager {
return &sessionManager{ return &sessionManager{
sessions: make(map[RequestID]Session), sessions: make(map[RequestID]Session),
originDialer: originDialer, originDialer: originDialer,
@@ -76,7 +73,7 @@ func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram
} }
// Attempt to bind the UDP socket for the new session // Attempt to bind the UDP socket for the new session
origin, err := s.originDialer(request.Dest) origin, err := s.originDialer.DialUDP(request.Dest)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -20,7 +20,7 @@ import (
func TestRegisterSession(t *testing.T) { func TestRegisterSession(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)) manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0))
request := v3.UDPSessionRegistrationDatagram{ request := v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID, RequestID: testRequestID,
@@ -76,7 +76,7 @@ func TestRegisterSession(t *testing.T) {
func TestGetSession_Empty(t *testing.T) { func TestGetSession_Empty(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)) manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0))
_, err := manager.GetSession(testRequestID) _, err := manager.GetSession(testRequestID)
if !errors.Is(err, v3.ErrSessionNotFound) { if !errors.Is(err, v3.ErrSessionNotFound) {
@@ -93,7 +93,7 @@ func TestRegisterSessionRateLimit(t *testing.T) {
flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows) flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
flowLimiterMock.EXPECT().Release().Times(0) flowLimiterMock.EXPECT().Release().Times(0)
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, flowLimiterMock) manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, flowLimiterMock)
request := v3.UDPSessionRegistrationDatagram{ request := v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID, RequestID: testRequestID,

View File

@@ -88,7 +88,7 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP
func TestDatagramConn_New(t *testing.T) { func TestDatagramConn_New(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
if conn == nil { if conn == nil {
t.Fatal("expected valid connection") t.Fatal("expected valid connection")
} }
@@ -97,7 +97,7 @@ func TestDatagramConn_New(t *testing.T) {
func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
quic := newMockQuicConn() quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
payload := []byte{0xef, 0xef} payload := []byte{0xef, 0xef}
err := conn.SendUDPSessionDatagram(payload) err := conn.SendUDPSessionDatagram(payload)
@@ -112,7 +112,7 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
quic := newMockQuicConn() quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable) err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
require.NoError(t, err) require.NoError(t, err)
@@ -134,7 +134,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
func TestDatagramConnServe_ApplicationClosed(t *testing.T) { func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
quic := newMockQuicConn() quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
defer cancel() defer cancel()
@@ -150,7 +150,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
defer cancel() defer cancel()
quic.ctx = ctx quic.ctx = ctx
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.Serve(t.Context()) err := conn.Serve(t.Context())
if !errors.Is(err, context.DeadlineExceeded) { if !errors.Is(err, context.DeadlineExceeded) {
@@ -161,7 +161,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) { func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
quic := &mockQuicConnReadError{err: net.ErrClosed} quic := &mockQuicConnReadError{err: net.ErrClosed}
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.Serve(t.Context()) err := conn.Serve(t.Context())
if !errors.Is(err, net.ErrClosed) { if !errors.Is(err, net.ErrClosed) {

View File

@@ -78,11 +78,14 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
edgeBindAddr := config.EdgeBindAddr edgeBindAddr := config.EdgeBindAddr
datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer) datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter()) // No reserved ingress services for now, hence the nil
ingressUDPService := ingress.NewUDPOriginService(nil, config.Log)
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingressUDPService, orchestrator.GetFlowLimiter())
edgeTunnelServer := EdgeTunnelServer{ edgeTunnelServer := EdgeTunnelServer{
config: config, config: config,
orchestrator: orchestrator, orchestrator: orchestrator,
ingressUDPProxy: ingressUDPService,
sessionManager: sessionManager, sessionManager: sessionManager,
datagramMetrics: datagramMetrics, datagramMetrics: datagramMetrics,
edgeAddrs: edgeIPs, edgeAddrs: edgeIPs,

View File

@@ -166,6 +166,7 @@ func (f *ipAddrFallback) ShouldGetNewAddress(connIndex uint8, err error) (needsN
type EdgeTunnelServer struct { type EdgeTunnelServer struct {
config *TunnelConfig config *TunnelConfig
orchestrator *orchestration.Orchestrator orchestrator *orchestration.Orchestrator
ingressUDPProxy ingress.UDPOriginProxy
sessionManager v3.SessionManager sessionManager v3.SessionManager
datagramMetrics v3.Metrics datagramMetrics v3.Metrics
edgeAddrHandler EdgeAddrHandler edgeAddrHandler EdgeAddrHandler
@@ -613,6 +614,7 @@ func (e *EdgeTunnelServer) serveQUIC(
datagramSessionManager = connection.NewDatagramV2Connection( datagramSessionManager = connection.NewDatagramV2Connection(
ctx, ctx,
conn, conn,
e.ingressUDPProxy,
e.config.ICMPRouterServer, e.config.ICMPRouterServer,
connIndex, connIndex,
e.config.RPCTimeout, e.config.RPCTimeout,