diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 8b6f784a..3960d0db 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -220,7 +220,15 @@ func prepareTunnelConfig( resolvedRegion = endpoint } - dnsService := origins.NewDNSResolver(log) + warpRoutingConfig := ingress.NewWarpRoutingConfig(&cfg.WarpRouting) + + // Setup origin dialer service and virtual services + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: ingress.NewDialer(warpRoutingConfig), + TCPWriteTimeout: c.Duration(flags.WriteStreamTimeout), + }, log) + dnsService := origins.NewDNSResolverService(origins.NewDNSDialer(), log) + originDialerService.AddReservedService(dnsService, []netip.AddrPort{origins.VirtualDNSServiceAddr}) tunnelConfig := &supervisor.TunnelConfig{ ClientConfig: clientConfig, @@ -250,6 +258,7 @@ func prepareTunnelConfig( QUICConnectionLevelFlowControlLimit: c.Uint64(flags.QuicConnLevelFlowControlLimit), QUICStreamLevelFlowControlLimit: c.Uint64(flags.QuicStreamLevelFlowControlLimit), OriginDNSService: dnsService, + OriginDialerService: originDialerService, } icmpRouter, err := newICMPRouter(c, log) if err != nil { @@ -258,10 +267,10 @@ func prepareTunnelConfig( tunnelConfig.ICMPRouterServer = icmpRouter } orchestratorConfig := &orchestration.Config{ - Ingress: &ingressRules, - WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting), - ConfigurationFlags: parseConfigFlags(c), - WriteTimeout: tunnelConfig.WriteStreamTimeout, + Ingress: &ingressRules, + WarpRouting: warpRoutingConfig, + OriginDialerService: originDialerService, + ConfigurationFlags: parseConfigFlags(c), } return tunnelConfig, orchestratorConfig, nil } diff --git a/connection/quic_connection_test.go b/connection/quic_connection_test.go index 6e85fcd7..8765fd29 100644 --- a/connection/quic_connection_test.go +++ b/connection/quic_connection_test.go @@ -30,6 +30,7 @@ import ( "golang.org/x/net/nettest" "github.com/cloudflare/cloudflared/client" + "github.com/cloudflare/cloudflared/config" cfdflow "github.com/cloudflare/cloudflared/flow" "github.com/cloudflare/cloudflared/datagramsession" @@ -823,6 +824,15 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan) var connIndex uint8 = 0 packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, connIndex, &log) + testDefaultDialer := ingress.NewDialer(ingress.WarpRoutingConfig{ + ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second}, + TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second}, + MaxActiveFlows: 0, + }) + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &log) datagramConn := &datagramV2Connection{ conn, @@ -830,7 +840,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) sessionManager, cfdflow.NewLimiter(0), datagramMuxer, - ingress.DefaultUDPDialer, + originDialer, packetRouter, 15 * time.Second, 0 * time.Second, diff --git a/connection/quic_datagram_v2.go b/connection/quic_datagram_v2.go index 94252551..aebead70 100644 --- a/connection/quic_datagram_v2.go +++ b/connection/quic_datagram_v2.go @@ -57,8 +57,8 @@ type datagramV2Connection struct { // datagramMuxer mux/demux datagrams from quic connection datagramMuxer *cfdquic.DatagramMuxerV2 - // ingressUDPProxy acts as the origin dialer for UDP requests - ingressUDPProxy ingress.UDPOriginProxy + // originDialer is the origin dialer for UDP requests + originDialer ingress.OriginUDPDialer // packetRouter acts as the origin router for ICMP requests packetRouter *ingress.PacketRouter @@ -70,7 +70,7 @@ type datagramV2Connection struct { func NewDatagramV2Connection(ctx context.Context, conn quic.Connection, - ingressUDPProxy ingress.UDPOriginProxy, + originDialer ingress.OriginUDPDialer, icmpRouter ingress.ICMPRouter, index uint8, rpcTimeout time.Duration, @@ -89,7 +89,7 @@ func NewDatagramV2Connection(ctx context.Context, sessionManager: sessionManager, flowLimiter: flowLimiter, datagramMuxer: datagramMuxer, - ingressUDPProxy: ingressUDPProxy, + originDialer: originDialer, packetRouter: packetRouter, rpcTimeout: rpcTimeout, streamWriteTimeout: streamWriteTimeout, @@ -159,7 +159,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID // Each session is a series of datagram from an eyeball to a dstIP:dstPort. // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. - originProxy, err := q.ingressUDPProxy.DialUDP(dstAddrPort) + originProxy, err := q.originDialer.DialUDP(dstAddrPort) if err != nil { log.Err(err).Msgf("Failed to create udp proxy to %s", dstAddrPort) tracing.EndWithErrorStatus(registerSpan, err) diff --git a/connection/quic_datagram_v2_test.go b/connection/quic_datagram_v2_test.go index 7e1f2f95..e4edac46 100644 --- a/connection/quic_datagram_v2_test.go +++ b/connection/quic_datagram_v2_test.go @@ -13,7 +13,6 @@ import ( "go.uber.org/mock/gomock" cfdflow "github.com/cloudflare/cloudflared/flow" - "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/mocks" ) @@ -84,7 +83,7 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) { datagramConn := NewDatagramV2Connection( t.Context(), conn, - ingress.DefaultUDPDialer, + nil, nil, 0, 0*time.Second, diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index f7e08004..139877ad 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -19,7 +19,7 @@ import ( type OriginConnection interface { // Stream should generally be implemented as a bidirectional io.Copy. Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) - Close() + Close() error } type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) @@ -48,16 +48,7 @@ func (tc *tcpConnection) Write(b []byte) (int, error) { } } - nBytes, err := tc.Conn.Write(b) - if err != nil { - tc.logger.Err(err).Msg("Error writing to the TCP connection") - } - - return nBytes, err -} - -func (tc *tcpConnection) Close() { - tc.Conn.Close() + return tc.Conn.Write(b) } // tcpOverWSConnection is an OriginConnection that streams to TCP over WS. @@ -75,8 +66,8 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri wsConn.Close() } -func (wc *tcpOverWSConnection) Close() { - wc.conn.Close() +func (wc *tcpOverWSConnection) Close() error { + return wc.conn.Close() } // socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS. @@ -95,5 +86,6 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io. wsConn.Close() } -func (sp *socksProxyOverWSConnection) Close() { +func (sp *socksProxyOverWSConnection) Close() error { + return nil } diff --git a/ingress/origin_dialer.go b/ingress/origin_dialer.go new file mode 100644 index 00000000..36ade327 --- /dev/null +++ b/ingress/origin_dialer.go @@ -0,0 +1,146 @@ +package ingress + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + "github.com/rs/zerolog" +) + +// OriginTCPDialer provides a TCP dial operation to a requested address. +type OriginTCPDialer interface { + DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) +} + +// OriginUDPDialer provides a UDP dial operation to a requested address. +type OriginUDPDialer interface { + DialUDP(addr netip.AddrPort) (net.Conn, error) +} + +// OriginDialer provides both TCP and UDP dial operations to an address. +type OriginDialer interface { + OriginTCPDialer + OriginUDPDialer +} + +type OriginConfig struct { + // The default Dialer used if no reserved services are found for an origin request. + DefaultDialer OriginDialer + // Timeout on write operations for TCP connections to the origin. + TCPWriteTimeout time.Duration +} + +// OriginDialerService provides a proxy TCP and UDP dialer to origin services while allowing reserved +// services to be provided. These reserved services are assigned to specific [netip.AddrPort]s +// and provide their own [OriginDialer]'s to handle origin dialing per protocol. +type OriginDialerService struct { + // Reserved TCP services for reserved AddrPort values + reservedTCPServices map[netip.AddrPort]OriginTCPDialer + // Reserved UDP services for reserved AddrPort values + reservedUDPServices map[netip.AddrPort]OriginUDPDialer + // The default Dialer used if no reserved services are found for an origin request + defaultDialer OriginDialer + defaultDialerM sync.RWMutex + // Write timeout for TCP connections + writeTimeout time.Duration + + logger *zerolog.Logger +} + +func NewOriginDialer(config OriginConfig, logger *zerolog.Logger) *OriginDialerService { + return &OriginDialerService{ + reservedTCPServices: map[netip.AddrPort]OriginTCPDialer{}, + reservedUDPServices: map[netip.AddrPort]OriginUDPDialer{}, + defaultDialer: config.DefaultDialer, + writeTimeout: config.TCPWriteTimeout, + logger: logger, + } +} + +// AddReservedService adds a reserved virtual service to dial to. +// Not locked and expected to be initialized before calling first dial and not afterwards. +func (d *OriginDialerService) AddReservedService(service OriginDialer, addrs []netip.AddrPort) { + for _, addr := range addrs { + d.reservedTCPServices[addr] = service + d.reservedUDPServices[addr] = service + } +} + +// UpdateDefaultDialer updates the default dialer. +func (d *OriginDialerService) UpdateDefaultDialer(dialer *Dialer) { + d.defaultDialerM.Lock() + defer d.defaultDialerM.Unlock() + d.defaultDialer = dialer +} + +// DialTCP will perform a dial TCP to the requested addr. +func (d *OriginDialerService) DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) { + conn, err := d.dialTCP(ctx, addr) + if err != nil { + return nil, err + } + // Assign the write timeout for the TCP operations + return &tcpConnection{ + Conn: conn, + writeTimeout: d.writeTimeout, + logger: d.logger, + }, nil +} + +func (d *OriginDialerService) dialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) { + // Check to see if any reserved services are available for this addr and call their dialer instead. + if dialer, ok := d.reservedTCPServices[addr]; ok { + return dialer.DialTCP(ctx, addr) + } + d.defaultDialerM.RLock() + dialer := d.defaultDialer + d.defaultDialerM.RUnlock() + return dialer.DialTCP(ctx, addr) +} + +// DialUDP will perform a dial UDP to the requested addr. +func (d *OriginDialerService) DialUDP(addr netip.AddrPort) (net.Conn, error) { + // Check to see if any reserved services are available for this addr and call their dialer instead. + if dialer, ok := d.reservedUDPServices[addr]; ok { + return dialer.DialUDP(addr) + } + d.defaultDialerM.RLock() + dialer := d.defaultDialer + d.defaultDialerM.RUnlock() + return dialer.DialUDP(addr) +} + +type Dialer struct { + Dialer net.Dialer +} + +func NewDialer(config WarpRoutingConfig) *Dialer { + return &Dialer{ + Dialer: net.Dialer{ + Timeout: config.ConnectTimeout.Duration, + KeepAlive: config.TCPKeepAlive.Duration, + }, + } +} + +func (d *Dialer) DialTCP(ctx context.Context, dest netip.AddrPort) (net.Conn, error) { + conn, err := d.Dialer.DialContext(ctx, "tcp", dest.String()) + if err != nil { + return nil, fmt.Errorf("unable to dial tcp to origin %s: %w", dest, err) + } + + return conn, nil +} + +func (d *Dialer) DialUDP(dest netip.AddrPort) (net.Conn, error) { + conn, err := d.Dialer.Dial("udp", dest.String()) + if err != nil { + return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err) + } + + return conn, nil +} diff --git a/ingress/origin_udp_proxy.go b/ingress/origin_udp_proxy.go deleted file mode 100644 index 357f553b..00000000 --- a/ingress/origin_udp_proxy.go +++ /dev/null @@ -1,66 +0,0 @@ -package ingress - -import ( - "fmt" - "net" - "net/netip" - - "github.com/rs/zerolog" -) - -// UDPOriginService provides a proxy UDP dialer to origin services while allowing reserved -// services to be provided. These reserved services are assigned to specific [netip.AddrPort]s -// and provide their own [UDPOriginProxy]s to handle UDP origin dialing. -type UDPOriginService struct { - // Reserved services for reserved AddrPort values - reservedServices map[netip.AddrPort]UDPOriginProxy - // The default UDP Dialer used if no reserved services are found for an origin request. - defaultDialer UDPOriginProxy - - logger *zerolog.Logger -} - -// UDPOriginProxy provides a UDP dial operation to a requested addr. -type UDPOriginProxy interface { - DialUDP(addr netip.AddrPort) (net.Conn, error) -} - -func NewUDPOriginService(reserved map[netip.AddrPort]UDPOriginProxy, logger *zerolog.Logger) *UDPOriginService { - return &UDPOriginService{ - reservedServices: reserved, - defaultDialer: DefaultUDPDialer, - logger: logger, - } -} - -// SetUDPDialer updates the default UDP Dialer used. -// Typically used in unit testing. -func (s *UDPOriginService) SetDefaultDialer(dialer UDPOriginProxy) { - s.defaultDialer = dialer -} - -// DialUDP will perform a dial UDP to the requested addr. -func (s *UDPOriginService) DialUDP(addr netip.AddrPort) (net.Conn, error) { - // Check to see if any reserved services are available for this addr and call their dialer instead. - if dialer, ok := s.reservedServices[addr]; ok { - return dialer.DialUDP(addr) - } - return s.defaultDialer.DialUDP(addr) -} - -type defaultUDPDialer struct{} - -var DefaultUDPDialer UDPOriginProxy = &defaultUDPDialer{} - -func (d *defaultUDPDialer) DialUDP(dest netip.AddrPort) (net.Conn, error) { - addr := net.UDPAddrFromAddrPort(dest) - - // We use nil as local addr to force runtime to find the best suitable local address IP given the destination - // address as context. - udpConn, err := net.DialUDP("udp", nil, addr) - if err != nil { - return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err) - } - - return udpConn, nil -} diff --git a/ingress/origins/dns.go b/ingress/origins/dns.go index b034646a..a3936b9f 100644 --- a/ingress/origins/dns.go +++ b/ingress/origins/dns.go @@ -45,20 +45,28 @@ type DNSResolverService struct { address netip.AddrPort addressM sync.RWMutex - dialer ingress.UDPOriginProxy + dialer ingress.OriginDialer resolver peekResolver logger *zerolog.Logger } -func NewDNSResolver(logger *zerolog.Logger) *DNSResolverService { +func NewDNSResolverService(dialer ingress.OriginDialer, logger *zerolog.Logger) *DNSResolverService { return &DNSResolverService{ address: defaultResolverAddr, - dialer: ingress.DefaultUDPDialer, + dialer: dialer, resolver: &resolver{dialFunc: net.Dial}, logger: logger, } } +func (s *DNSResolverService) DialTCP(ctx context.Context, _ netip.AddrPort) (net.Conn, error) { + s.addressM.RLock() + dest := s.address + s.addressM.RUnlock() + // The dialer ignores the provided address because the request will instead go to the local DNS resolver. + return s.dialer.DialTCP(ctx, dest) +} + func (s *DNSResolverService) DialUDP(_ netip.AddrPort) (net.Conn, error) { s.addressM.RLock() dest := s.address @@ -155,3 +163,18 @@ func (r *resolver) peekDial(ctx context.Context, network, address string) (net.C r.address = address return r.dialFunc(network, address) } + +// NewDNSDialer creates a custom dialer for the DNS resolver service to utilize. +func NewDNSDialer() *ingress.Dialer { + return &ingress.Dialer{ + Dialer: net.Dialer{ + // We want short timeouts for the DNS requests + Timeout: 5 * time.Second, + // We do not want keep alive since the edge will not reuse TCP connections per request + KeepAlive: -1, + KeepAliveConfig: net.KeepAliveConfig{ + Enable: false, + }, + }, + } +} diff --git a/ingress/origins/dns_test.go b/ingress/origins/dns_test.go index db8ebf28..a137c814 100644 --- a/ingress/origins/dns_test.go +++ b/ingress/origins/dns_test.go @@ -12,7 +12,7 @@ import ( func TestDNSResolver_DefaultResolver(t *testing.T) { log := zerolog.Nop() - service := NewDNSResolver(&log) + service := NewDNSResolverService(NewDNSDialer(), &log) mockResolver := &mockPeekResolver{ address: "127.0.0.2:53", } @@ -24,7 +24,7 @@ func TestDNSResolver_DefaultResolver(t *testing.T) { func TestDNSResolver_UpdateResolverAddress(t *testing.T) { log := zerolog.Nop() - service := NewDNSResolver(&log) + service := NewDNSResolverService(NewDNSDialer(), &log) mockResolver := &mockPeekResolver{} service.resolver = mockResolver @@ -51,7 +51,7 @@ func TestDNSResolver_UpdateResolverAddress(t *testing.T) { func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) { log := zerolog.Nop() - service := NewDNSResolver(&log) + service := NewDNSResolverService(NewDNSDialer(), &log) mockResolver := &mockPeekResolver{} service.resolver = mockResolver @@ -77,7 +77,7 @@ func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) { func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) { log := zerolog.Nop() - service := NewDNSResolver(&log) + service := NewDNSResolverService(NewDNSDialer(), &log) resolverErr := errors.New("test resolver error") mockResolver := &mockPeekResolver{err: resolverErr} service.resolver = mockResolver @@ -93,13 +93,12 @@ func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) { } } -func TestDNSResolver_DialUsesResolvedAddress(t *testing.T) { +func TestDNSResolver_DialUDPUsesResolvedAddress(t *testing.T) { log := zerolog.Nop() - service := NewDNSResolver(&log) + mockDialer := &mockDialer{expected: defaultResolverAddr} + service := NewDNSResolverService(mockDialer, &log) mockResolver := &mockPeekResolver{} service.resolver = mockResolver - mockDialer := &mockDialer{expected: defaultResolverAddr} - service.dialer = mockDialer // Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53 _, err := service.DialUDP(netip.MustParseAddrPort("127.0.0.2:53")) @@ -108,6 +107,20 @@ func TestDNSResolver_DialUsesResolvedAddress(t *testing.T) { } } +func TestDNSResolver_DialTCPUsesResolvedAddress(t *testing.T) { + log := zerolog.Nop() + mockDialer := &mockDialer{expected: defaultResolverAddr} + service := NewDNSResolverService(mockDialer, &log) + mockResolver := &mockPeekResolver{} + service.resolver = mockResolver + + // Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53 + _, err := service.DialTCP(t.Context(), netip.MustParseAddrPort("127.0.0.2:53")) + if err != nil { + t.Error(err) + } +} + type mockPeekResolver struct { err error address string @@ -126,6 +139,13 @@ type mockDialer struct { expected netip.AddrPort } +func (d *mockDialer) DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) { + if d.expected != addr { + return nil, errors.New("unexpected address dialed") + } + return nil, nil +} + func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) { if d.expected != addr { return nil, errors.New("unexpected address dialed") diff --git a/orchestration/config.go b/orchestration/config.go index 04c7a0ab..b87b69a6 100644 --- a/orchestration/config.go +++ b/orchestration/config.go @@ -2,7 +2,6 @@ package orchestration import ( "encoding/json" - "time" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/ingress" @@ -20,9 +19,9 @@ type newLocalConfig struct { // Config is the original config as read and parsed by cloudflared. type Config struct { - Ingress *ingress.Ingress - WarpRouting ingress.WarpRoutingConfig - WriteTimeout time.Duration + Ingress *ingress.Ingress + WarpRouting ingress.WarpRoutingConfig + OriginDialerService *ingress.OriginDialerService // Extra settings used to configure this instance but that are not eligible for remotely management // ie. (--protocol, --loglevel, ...) diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go index abfd1f9b..9840bd36 100644 --- a/orchestration/orchestrator.go +++ b/orchestration/orchestrator.go @@ -38,7 +38,9 @@ type Orchestrator struct { tags []pogs.Tag // flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit. flowLimiter cfdflow.Limiter - log *zerolog.Logger + // Origin dialer service to manage egress socket dialing. + originDialerService *ingress.OriginDialerService + log *zerolog.Logger // orchestrator must not handle any more updates after shutdownC is closed shutdownC <-chan struct{} @@ -50,18 +52,20 @@ func NewOrchestrator(ctx context.Context, config *Config, tags []pogs.Tag, internalRules []ingress.Rule, - log *zerolog.Logger) (*Orchestrator, error) { + log *zerolog.Logger, +) (*Orchestrator, error) { o := &Orchestrator{ // Lowest possible version, any remote configuration will have version higher than this // Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it // will start at version 0. - currentVersion: -1, - internalRules: internalRules, - config: config, - tags: tags, - flowLimiter: cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows), - log: log, - shutdownC: ctx.Done(), + currentVersion: -1, + internalRules: internalRules, + config: config, + tags: tags, + flowLimiter: cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows), + originDialerService: config.OriginDialerService, + log: log, + shutdownC: ctx.Done(), } if err := o.updateIngress(*config.Ingress, config.WarpRouting); err != nil { return nil, err @@ -175,7 +179,15 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i // Update the flow limit since the configuration might have changed o.flowLimiter.SetLimit(warpRouting.MaxActiveFlows) - proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.flowLimiter, o.config.WriteTimeout, o.log) + // Update the origin dialer service with the new dialer settings + // We need to update the dialer here instead of creating a new instance of OriginDialerService because it has + // its own references and go routines. Specifically, the UDP dialer is a reference to this same service all the + // way into the datagram manager. Reconstructing the datagram manager is not something we currently provide during + // runtime in response to a configuration push except when starting a tunnel connection. + o.originDialerService.UpdateDefaultDialer(ingress.NewDialer(warpRouting)) + + // Create and replace the origin proxy with a new instance + proxy := proxy.NewOriginProxy(ingressRules, o.originDialerService, o.tags, o.flowLimiter, o.log) o.proxy.Store(proxy) o.config.Ingress = &ingressRules o.config.WarpRouting = warpRouting diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index aeed4860..7a14b2d4 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -41,6 +41,11 @@ var ( Value: "test", }, } + testDefaultDialer = ingress.NewDialer(ingress.WarpRoutingConfig{ + ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second}, + TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second}, + MaxActiveFlows: 0, + }) ) // TestUpdateConfiguration tests that @@ -50,8 +55,13 @@ var ( // - configurations can be deserialized // - receiving an old version is noop func TestUpdateConfiguration(t *testing.T) { + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) initConfig := &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } orchestrator, err := NewOrchestrator(t.Context(), initConfig, testTags, []ingress.Rule{ingress.NewManagementRule(management.New("management.argotunnel.com", false, "1.1.1.1:80", uuid.Nil, "", &testLogger, nil))}, &testLogger) require.NoError(t, err) @@ -179,8 +189,13 @@ func TestUpdateConfiguration(t *testing.T) { // Validates that a new version 0 will be applied if the configuration is loaded locally. // This will happen when a locally managed tunnel is migrated to remote configuration and receives its first configuration. func TestUpdateConfiguration_FromMigration(t *testing.T) { + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) initConfig := &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } orchestrator, err := NewOrchestrator(t.Context(), initConfig, testTags, []ingress.Rule{}, &testLogger) require.NoError(t, err) @@ -205,8 +220,13 @@ func TestUpdateConfiguration_FromMigration(t *testing.T) { // Validates that the default ingress rule will be set if there is no rule provided from the remote. func TestUpdateConfiguration_WithoutIngressRule(t *testing.T) { + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) initConfig := &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } orchestrator, err := NewOrchestrator(t.Context(), initConfig, testTags, []ingress.Rule{}, &testLogger) require.NoError(t, err) @@ -244,6 +264,11 @@ func TestConcurrentUpdateAndRead(t *testing.T) { require.NoError(t, err) defer tcpOrigin.Close() + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) + var ( configJSONV1 = []byte(fmt.Sprintf(` { @@ -296,7 +321,8 @@ func TestConcurrentUpdateAndRead(t *testing.T) { appliedV2 = make(chan struct{}) initConfig = &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } ) @@ -313,7 +339,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) { go func() { serveTCPOrigin(t, tcpOrigin, &wg) }() - for i := 0; i < concurrentRequests; i++ { + for i := range concurrentRequests { originProxy, err := orchestrator.GetOriginProxy() require.NoError(t, err) wg.Add(1) @@ -323,48 +349,37 @@ func TestConcurrentUpdateAndRead(t *testing.T) { assert.NoError(t, err, "proxyHTTP %d failed %v", i, err) defer resp.Body.Close() - var warpRoutingDisabled bool // The response can be from initOrigin, http_status:204 or http_status:418 switch resp.StatusCode { - // v1 proxy, warp enabled + // v1 proxy case 200: body, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, t.Name(), string(body)) - warpRoutingDisabled = false - // v2 proxy, warp disabled + // v2 proxy case 204: assert.Greater(t, i, concurrentRequests/4) - warpRoutingDisabled = true - // v3 proxy, warp enabled + // v3 proxy case 418: assert.Greater(t, i, concurrentRequests/2) - warpRoutingDisabled = false } // Once we have originProxy, it won't be changed by configuration updates. // We can infer the version by the ProxyHTTP response code pr, pw := io.Pipe() - w := newRespReadWriteFlusher() // Write TCP message and make sure it's echo back. This has to be done in a go routune since ProxyTCP doesn't // return until the stream is closed. - if !warpRoutingDisabled { - wg.Add(1) - go func() { - defer wg.Done() - defer pw.Close() - tcpEyeball(t, pw, tcpBody, w) - }() - } + wg.Add(1) + go func() { + defer wg.Done() + defer pw.Close() + tcpEyeball(t, pw, tcpBody, w) + }() err = proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), w, pr) - if warpRoutingDisabled { - assert.Error(t, err, "expect proxyTCP %d to return error", i) - } else { - assert.NoError(t, err, "proxyTCP %d failed %v", i, err) - } + assert.NoError(t, err, "proxyTCP %d failed %v", i, err) }(i, originProxy) if i == concurrentRequests/4 { @@ -406,39 +421,47 @@ func TestOverrideWarpRoutingConfigWithLocalValues(t *testing.T) { require.EqualValues(t, expectedValue, warpRouting["maxActiveFlows"]) } - remoteValue := uint64(100) - remoteIngress := ingress.Ingress{} + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) + + // All the possible values set for MaxActiveFlows from the various points that can provide it: + // 1. Initialized value + // 2. Local CLI flag config + // 3. Remote configuration value + initValue := uint64(0) + localValue := uint64(100) + remoteValue := uint64(500) + + initConfig := &Config{ + Ingress: &ingress.Ingress{}, + WarpRouting: ingress.WarpRoutingConfig{ + MaxActiveFlows: initValue, + }, + OriginDialerService: originDialer, + ConfigurationFlags: map[string]string{ + flags.MaxActiveFlows: fmt.Sprintf("%d", localValue), + }, + } + + // We expect the local configuration flag to be the starting value + orchestrator, err := NewOrchestrator(ctx, initConfig, testTags, []ingress.Rule{}, &testLogger) + require.NoError(t, err) + + assertMaxActiveFlows(orchestrator, localValue) + + // Assigning the MaxActiveFlows in the remote config should be ignored over the local config remoteWarpConfig := ingress.WarpRoutingConfig{ MaxActiveFlows: remoteValue, } - remoteConfig := &Config{ - Ingress: &remoteIngress, - WarpRouting: remoteWarpConfig, - ConfigurationFlags: map[string]string{}, - } - orchestrator, err := NewOrchestrator(ctx, remoteConfig, testTags, []ingress.Rule{}, &testLogger) - require.NoError(t, err) - assertMaxActiveFlows(orchestrator, remoteValue) - - // Add a local override for the maxActiveFlows - localValue := uint64(500) - remoteConfig.ConfigurationFlags[flags.MaxActiveFlows] = fmt.Sprintf("%d", localValue) // Force a configuration refresh - err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig) + err = orchestrator.updateIngress(ingress.Ingress{}, remoteWarpConfig) require.NoError(t, err) // Check the value being used is the local one assertMaxActiveFlows(orchestrator, localValue) - - // Remove local override for the maxActiveFlows - delete(remoteConfig.ConfigurationFlags, flags.MaxActiveFlows) - // Force a configuration refresh - err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig) - require.NoError(t, err) - - // Check the value being used is now the remote again - assertMaxActiveFlows(orchestrator, remoteValue) } func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Response, error) { @@ -546,6 +569,10 @@ func updateWithValidation(t *testing.T, orchestrator *Orchestrator, version int3 // TestClosePreviousProxies makes sure proxies started in the previous configuration version are shutdown func TestClosePreviousProxies(t *testing.T) { + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) var ( hostname = "hello.tunnel1.org" configWithHelloWorld = []byte(fmt.Sprintf(` @@ -576,7 +603,8 @@ func TestClosePreviousProxies(t *testing.T) { } `) initConfig = &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } ) @@ -638,8 +666,13 @@ func TestPersistentConnection(t *testing.T) { hostname = "http://ws.tunnel.org" ) msg := t.Name() + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) initConfig := &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } orchestrator, err := NewOrchestrator(t.Context(), initConfig, testTags, []ingress.Rule{}, &testLogger) require.NoError(t, err) @@ -752,8 +785,9 @@ func TestSerializeLocalConfig(t *testing.T) { ConfigurationFlags: map[string]string{"a": "b"}, } - result, _ := json.Marshal(c) - fmt.Println(string(result)) + result, err := json.Marshal(c) + require.NoError(t, err) + require.JSONEq(t, `{"__configuration_flags":{"a":"b"},"ingress":[],"warp-routing":{"connectTimeout":0,"tcpKeepAlive":0}}`, string(result)) } func wsEcho(w http.ResponseWriter, r *http.Request) { diff --git a/proxy/proxy.go b/proxy/proxy.go index 733ad385..e5d7fc6d 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -5,11 +5,11 @@ import ( "fmt" "io" "net/http" + "net/netip" "strconv" "time" "github.com/pkg/errors" - pkgerrors "github.com/pkg/errors" "github.com/rs/zerolog" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -35,7 +35,7 @@ const ( // Proxy represents a means to Proxy between cloudflared and the origin services. type Proxy struct { ingressRules ingress.Ingress - warpRouting *ingress.WarpRoutingService + originDialer ingress.OriginTCPDialer tags []pogs.Tag flowLimiter cfdflow.Limiter log *zerolog.Logger @@ -44,21 +44,19 @@ type Proxy struct { // NewOriginProxy returns a new instance of the Proxy struct. func NewOriginProxy( ingressRules ingress.Ingress, - warpRouting ingress.WarpRoutingConfig, + originDialer ingress.OriginDialer, tags []pogs.Tag, flowLimiter cfdflow.Limiter, - writeTimeout time.Duration, log *zerolog.Logger, ) *Proxy { proxy := &Proxy{ ingressRules: ingressRules, + originDialer: originDialer, tags: tags, flowLimiter: flowLimiter, log: log, } - proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout) - return proxy } @@ -146,24 +144,18 @@ func (p *Proxy) ProxyHTTP( // ProxyTCP proxies to a TCP connection between the origin service and cloudflared. func (p *Proxy) ProxyTCP( ctx context.Context, - rwa connection.ReadWriteAcker, + conn connection.ReadWriteAcker, req *connection.TCPRequest, ) error { incrementTCPRequests() defer decrementTCPConcurrentRequests() - if p.warpRouting == nil { - err := errors.New(`cloudflared received a request from WARP client, but your configuration has disabled ingress from WARP clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`) - p.log.Error().Msg(err.Error()) - return err - } - logger := newTCPLogger(p.log, req) // Try to start a new flow if err := p.flowLimiter.Acquire(management.TCP.String()); err != nil { logger.Warn().Msg("Too many concurrent flows being handled, rejecting tcp proxy") - return pkgerrors.Wrap(err, "failed to start tcp flow due to rate limiting") + return errors.Wrap(err, "failed to start tcp flow due to rate limiting") } defer p.flowLimiter.Release() @@ -173,7 +165,14 @@ func (p *Proxy) ProxyTCP( tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger) logger.Debug().Msg("tcp proxy stream started") - if err := p.proxyStream(tracedCtx, rwa, req.Dest, p.warpRouting.Proxy, &logger); err != nil { + // Parse the destination into a netip.AddrPort + dest, err := netip.ParseAddrPort(req.Dest) + if err != nil { + logRequestError(&logger, err) + return err + } + + if err := p.proxyTCPStream(tracedCtx, conn, dest, p.originDialer, &logger); err != nil { logRequestError(&logger, err) return err } @@ -279,14 +278,14 @@ func (p *Proxy) proxyStream( tr *tracing.TracedContext, rwa connection.ReadWriteAcker, dest string, - connectionProxy ingress.StreamBasedOriginProxy, + originDialer ingress.StreamBasedOriginProxy, logger *zerolog.Logger, ) error { ctx := tr.Context _, connectSpan := tr.Tracer().Start(ctx, "stream-connect") start := time.Now() - originConn, err := connectionProxy.EstablishConnection(ctx, dest, logger) + originConn, err := originDialer.EstablishConnection(ctx, dest, logger) if err != nil { connectStreamErrors.Inc() tracing.EndWithErrorStatus(connectSpan, err) @@ -310,6 +309,45 @@ func (p *Proxy) proxyStream( return nil } +// proxyTCPStream proxies private network type TCP connections as a stream towards an available origin. +// +// This is different than proxyStream because it's not leveraged ingress rule services and uses the +// originDialer from OriginDialerService. +func (p *Proxy) proxyTCPStream( + tr *tracing.TracedContext, + tunnelConn connection.ReadWriteAcker, + dest netip.AddrPort, + originDialer ingress.OriginTCPDialer, + logger *zerolog.Logger, +) error { + ctx := tr.Context + _, connectSpan := tr.Tracer().Start(ctx, "stream-connect") + + start := time.Now() + originConn, err := originDialer.DialTCP(ctx, dest) + if err != nil { + connectStreamErrors.Inc() + tracing.EndWithErrorStatus(connectSpan, err) + return err + } + connectSpan.End() + defer originConn.Close() + logger.Debug().Msg("origin connection established") + + encodedSpans := tr.GetSpans() + + if err := tunnelConn.AckConnection(encodedSpans); err != nil { + connectStreamErrors.Inc() + return err + } + + connectLatency.Observe(float64(time.Since(start).Milliseconds())) + logger.Debug().Msg("proxy stream acknowledged") + + stream.Pipe(tunnelConn, originConn, logger) + return nil +} + func (p *Proxy) proxyLocalRequest(proxy ingress.HTTPLocalProxy, w connection.ResponseWriter, req *http.Request, isWebsocket bool) { if isWebsocket { // These headers are added since they are stripped off during an eyeball request to origintunneld, but they diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 56c9cab9..f5038122 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -33,17 +33,17 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) var ( - testTags = []pogs.Tag{{Name: "Name", Value: "value"}} - noWarpRouting = ingress.WarpRoutingConfig{} - testWarpRouting = ingress.WarpRoutingConfig{ - ConnectTimeout: config.CustomDuration{Duration: time.Second}, - } + testTags = []pogs.Tag{{Name: "Name", Value: "value"}} + testDefaultDialer = ingress.NewDialer(ingress.WarpRoutingConfig{ + ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second}, + TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second}, + MaxActiveFlows: 0, + }) ) type mockHTTPRespWriter struct { @@ -162,7 +162,12 @@ func TestProxySingleOrigin(t *testing.T) { require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log) + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &log) + + proxy := NewOriginProxy(ingressRule, originDialer, testTags, cfdflow.NewLimiter(0), &log) t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) @@ -357,7 +362,7 @@ type MultipleIngressTest struct { } func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.UnvalidatedIngressRule, tests []MultipleIngressTest) { - ingress, err := ingress.ParseIngress(&config.Configuration{ + ingressRule, err := ingress.ParseIngress(&config.Configuration{ TunnelID: t.Name(), Ingress: unvalidatedIngress, }) @@ -366,9 +371,14 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat log := zerolog.Nop() ctx, cancel := context.WithCancel(t.Context()) - require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) + require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log) + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &log) + + proxy := NewOriginProxy(ingressRule, originDialer, testTags, cfdflow.NewLimiter(0), &log) for _, test := range tests { responseWriter := newMockHTTPRespWriter() @@ -416,7 +426,12 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log) + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &log) + + proxy := NewOriginProxy(ing, originDialer, testTags, cfdflow.NewLimiter(0), &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) @@ -467,7 +482,7 @@ func (r *replayer) Bytes() []byte { // WS - TCP: When a tcp based ingress is configured on the origin and the // eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access). func TestConnections(t *testing.T) { - logger := logger.Create(nil) + log := zerolog.Nop() replayer := &replayer{rw: bytes.NewBuffer([]byte{})} type args struct { ingressServiceScheme string @@ -475,9 +490,6 @@ func TestConnections(t *testing.T) { eyeballResponseWriter connection.ResponseWriter eyeballRequestBody io.ReadCloser - // Can be set to nil to show warp routing is not enabled. - warpRoutingService *ingress.WarpRoutingService - // eyeball connection type. connectionType connection.Type @@ -488,6 +500,11 @@ func TestConnections(t *testing.T) { flowLimiterResponse error } + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) + type want struct { message []byte headers http.Header @@ -530,7 +547,6 @@ func TestConnections(t *testing.T) { originService: runEchoTCPService, eyeballResponseWriter: newTCPRespWriter(replayer), eyeballRequestBody: newTCPRequestBody([]byte("test2")), - warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), connectionType: connection.TypeTCP, requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, @@ -548,7 +564,6 @@ func TestConnections(t *testing.T) { originService: runEchoWSService, // eyeballResponseWriter gets set after roundtrip dial. eyeballRequestBody: newPipedWSRequestBody([]byte("test3")), - warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, }, @@ -601,23 +616,6 @@ func TestConnections(t *testing.T) { headers: map[string][]string{}, }, }, - { - name: "tcp-tcp proxy without warpRoutingService enabled", - args: args{ - ingressServiceScheme: "tcp://", - originService: runEchoTCPService, - eyeballResponseWriter: newTCPRespWriter(replayer), - eyeballRequestBody: newTCPRequestBody([]byte("test2")), - connectionType: connection.TypeTCP, - requestHeaders: map[string][]string{ - "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, - }, - }, - want: want{ - message: []byte{}, - err: true, - }, - }, { name: "ws-ws proxy when origin is different", args: args{ @@ -670,7 +668,6 @@ func TestConnections(t *testing.T) { originService: runEchoTCPService, eyeballResponseWriter: newTCPRespWriter(replayer), eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")), - warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), connectionType: connection.TypeTCP, requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, @@ -693,7 +690,7 @@ func TestConnections(t *testing.T) { test.args.originService(t, ln) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) - _ = ingressRule.StartOrigins(logger, ctx.Done()) + _ = ingressRule.StartOrigins(&log, ctx.Done()) // Mock flow limiter ctrl := gomock.NewController(t) @@ -702,8 +699,7 @@ func TestConnections(t *testing.T) { flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse) flowLimiter.EXPECT().Release().AnyTimes() - proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger) - proxy.warpRouting = test.args.warpRoutingService + proxy := NewOriginProxy(ingressRule, originDialer, testTags, flowLimiter, &log) dest := ln.Addr().String() req, err := http.NewRequest( diff --git a/quic/v3/manager.go b/quic/v3/manager.go index a22adcea..d2456ff4 100644 --- a/quic/v3/manager.go +++ b/quic/v3/manager.go @@ -40,13 +40,13 @@ type SessionManager interface { type sessionManager struct { sessions map[RequestID]Session mutex sync.RWMutex - originDialer ingress.UDPOriginProxy + originDialer ingress.OriginUDPDialer limiter cfdflow.Limiter metrics Metrics log *zerolog.Logger } -func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer ingress.UDPOriginProxy, limiter cfdflow.Limiter) SessionManager { +func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer ingress.OriginUDPDialer, limiter cfdflow.Limiter) SessionManager { return &sessionManager{ sessions: make(map[RequestID]Session), originDialer: originDialer, diff --git a/quic/v3/manager_test.go b/quic/v3/manager_test.go index e6335b6a..80daf685 100644 --- a/quic/v3/manager_test.go +++ b/quic/v3/manager_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/mocks" cfdflow "github.com/cloudflare/cloudflared/flow" @@ -18,9 +19,21 @@ import ( v3 "github.com/cloudflare/cloudflared/quic/v3" ) +var ( + testDefaultDialer = ingress.NewDialer(ingress.WarpRoutingConfig{ + ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second}, + TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second}, + MaxActiveFlows: 0, + }) +) + func TestRegisterSession(t *testing.T) { log := zerolog.Nop() - manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)) + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) + manager := v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)) request := v3.UDPSessionRegistrationDatagram{ RequestID: testRequestID, @@ -76,7 +89,11 @@ func TestRegisterSession(t *testing.T) { func TestGetSession_Empty(t *testing.T) { log := zerolog.Nop() - manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)) + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) + manager := v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)) _, err := manager.GetSession(testRequestID) if !errors.Is(err, v3.ErrSessionNotFound) { @@ -86,6 +103,10 @@ func TestGetSession_Empty(t *testing.T) { func TestRegisterSessionRateLimit(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) ctrl := gomock.NewController(t) flowLimiterMock := mocks.NewMockLimiter(ctrl) @@ -93,7 +114,7 @@ func TestRegisterSessionRateLimit(t *testing.T) { flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows) flowLimiterMock.EXPECT().Release().Times(0) - manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, flowLimiterMock) + manager := v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, flowLimiterMock) request := v3.UDPSessionRegistrationDatagram{ RequestID: testRequestID, diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go index e3ca81c0..729abd3c 100644 --- a/quic/v3/muxer_test.go +++ b/quic/v3/muxer_test.go @@ -88,7 +88,11 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP func TestDatagramConn_New(t *testing.T) { log := zerolog.Nop() - conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) + conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) if conn == nil { t.Fatal("expected valid connection") } @@ -96,8 +100,12 @@ func TestDatagramConn_New(t *testing.T) { func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) payload := []byte{0xef, 0xef} err := conn.SendUDPSessionDatagram(payload) @@ -111,8 +119,12 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable) require.NoError(t, err) @@ -133,8 +145,12 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { func TestDatagramConnServe_ApplicationClosed(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) defer cancel() @@ -146,11 +162,15 @@ func TestDatagramConnServe_ApplicationClosed(t *testing.T) { func TestDatagramConnServe_ConnectionClosed(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := newMockQuicConn() ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) defer cancel() quic.ctx = ctx - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.Serve(t.Context()) if !errors.Is(err, context.DeadlineExceeded) { @@ -160,8 +180,12 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) { func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := &mockQuicConnReadError{err: net.ErrClosed} - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DefaultUDPDialer, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.Serve(t.Context()) if !errors.Is(err, net.ErrClosed) { diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index fa70d29a..cb25d68a 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net" - "net/netip" "strings" "time" @@ -14,8 +13,6 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" - "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/ingress/origins" "github.com/cloudflare/cloudflared/orchestration" v3 "github.com/cloudflare/cloudflared/quic/v3" "github.com/cloudflare/cloudflared/retry" @@ -81,16 +78,11 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer) - // Setup the reserved virtual origins - reservedServices := map[netip.AddrPort]ingress.UDPOriginProxy{} - reservedServices[origins.VirtualDNSServiceAddr] = config.OriginDNSService - ingressUDPService := ingress.NewUDPOriginService(reservedServices, config.Log) - sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingressUDPService, orchestrator.GetFlowLimiter()) + sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, config.OriginDialerService, orchestrator.GetFlowLimiter()) edgeTunnelServer := EdgeTunnelServer{ config: config, orchestrator: orchestrator, - ingressUDPProxy: ingressUDPService, sessionManager: sessionManager, datagramMetrics: datagramMetrics, edgeAddrs: edgeIPs, diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index ede0f8ed..b73eecb9 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -61,11 +61,12 @@ type TunnelConfig struct { NeedPQ bool - NamedTunnel *connection.TunnelProperties - ProtocolSelector connection.ProtocolSelector - EdgeTLSConfigs map[connection.Protocol]*tls.Config - ICMPRouterServer ingress.ICMPRouterServer - OriginDNSService *origins.DNSResolverService + NamedTunnel *connection.TunnelProperties + ProtocolSelector connection.ProtocolSelector + EdgeTLSConfigs map[connection.Protocol]*tls.Config + ICMPRouterServer ingress.ICMPRouterServer + OriginDNSService *origins.DNSResolverService + OriginDialerService *ingress.OriginDialerService RPCTimeout time.Duration WriteStreamTimeout time.Duration @@ -168,7 +169,6 @@ func (f *ipAddrFallback) ShouldGetNewAddress(connIndex uint8, err error) (needsN type EdgeTunnelServer struct { config *TunnelConfig orchestrator *orchestration.Orchestrator - ingressUDPProxy ingress.UDPOriginProxy sessionManager v3.SessionManager datagramMetrics v3.Metrics edgeAddrHandler EdgeAddrHandler @@ -616,7 +616,7 @@ func (e *EdgeTunnelServer) serveQUIC( datagramSessionManager = connection.NewDatagramV2Connection( ctx, conn, - e.ingressUDPProxy, + e.config.OriginDialerService, e.config.ICMPRouterServer, connIndex, e.config.RPCTimeout,