diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 7961c813..8b6f784a 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -25,6 +25,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/ingress/origins" "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" @@ -219,6 +220,8 @@ func prepareTunnelConfig( resolvedRegion = endpoint } + dnsService := origins.NewDNSResolver(log) + tunnelConfig := &supervisor.TunnelConfig{ ClientConfig: clientConfig, GracePeriod: gracePeriod, @@ -246,6 +249,7 @@ func prepareTunnelConfig( DisableQUICPathMTUDiscovery: c.Bool(flags.QuicDisablePathMTUDiscovery), QUICConnectionLevelFlowControlLimit: c.Uint64(flags.QuicConnLevelFlowControlLimit), QUICStreamLevelFlowControlLimit: c.Uint64(flags.QuicStreamLevelFlowControlLimit), + OriginDNSService: dnsService, } icmpRouter, err := newICMPRouter(c, log) if err != nil { diff --git a/ingress/origin_udp_proxy.go b/ingress/origin_udp_proxy.go index eab7d783..357f553b 100644 --- a/ingress/origin_udp_proxy.go +++ b/ingress/origin_udp_proxy.go @@ -22,7 +22,7 @@ type UDPOriginService struct { // UDPOriginProxy provides a UDP dial operation to a requested addr. type UDPOriginProxy interface { - DialUDP(addr netip.AddrPort) (*net.UDPConn, error) + DialUDP(addr netip.AddrPort) (net.Conn, error) } func NewUDPOriginService(reserved map[netip.AddrPort]UDPOriginProxy, logger *zerolog.Logger) *UDPOriginService { @@ -40,7 +40,7 @@ func (s *UDPOriginService) SetDefaultDialer(dialer UDPOriginProxy) { } // DialUDP will perform a dial UDP to the requested addr. 
-func (s *UDPOriginService) DialUDP(addr netip.AddrPort) (*net.UDPConn, error) {
+func (s *UDPOriginService) DialUDP(addr netip.AddrPort) (net.Conn, error) {
 	// Check to see if any reserved services are available for this addr and call their dialer instead.
 	if dialer, ok := s.reservedServices[addr]; ok {
 		return dialer.DialUDP(addr)
@@ -52,7 +52,7 @@ type defaultUDPDialer struct{}
 
 var DefaultUDPDialer UDPOriginProxy = &defaultUDPDialer{}
 
-func (d *defaultUDPDialer) DialUDP(dest netip.AddrPort) (*net.UDPConn, error) {
+func (d *defaultUDPDialer) DialUDP(dest netip.AddrPort) (net.Conn, error) {
 	addr := net.UDPAddrFromAddrPort(dest)
 
 	// We use nil as local addr to force runtime to find the best suitable local address IP given the destination
diff --git a/ingress/origins/dns.go b/ingress/origins/dns.go
new file mode 100644
index 00000000..b034646a
--- /dev/null
+++ b/ingress/origins/dns.go
@@ -0,0 +1,157 @@
+package origins
+
+import (
+	"context"
+	"net"
+	"net/netip"
+	"sync"
+	"time"
+
+	"github.com/rs/zerolog"
+
+	"github.com/cloudflare/cloudflared/ingress"
+)
+
+const (
+	// defaultLookupHost is the record looked up on every refresh. We need a DNS record:
+	// 1. That will be around for as long as cloudflared is.
+	// 2. That Cloudflare controls: to allow us to make changes if needed.
+	// 3. That is an external record to a typical customer's network: enforcing that the DNS request goes to
+	//    the local DNS resolver rather than being answered by any local /etc/hosts configuration.
+	// 4. That cloudflared would normally query: ensuring that users with a positive security model for DNS
+	//    queries don't need to adjust anything.
+	//
+	// This hostname is one that is used during the edge discovery process and as such satisfies the above
+	// constraints.
+	defaultLookupHost          = "region1.v2.argotunnel.com"
+	// defaultResolverPort is the port assumed when the observed resolver address carries no port.
+	defaultResolverPort uint16 = 53
+
+	// We want the refresh time to be short to accommodate DNS resolver changes locally, but not too frequent as to
+	// shuffle the resolver if multiple are configured.
+	refreshFreq    = 5 * time.Minute
+	// refreshTimeout bounds a single refresh lookup.
+	refreshTimeout = 5 * time.Second
+)
+
+var (
+	// VirtualDNSServiceAddr is the virtual address (port 53) under which the DNS resolver service is reserved.
+	VirtualDNSServiceAddr = netip.AddrPortFrom(netip.MustParseAddr("2606:4700:0cf1:2000:0000:0000:0000:0001"), 53)
+
+	// defaultResolverAddr is the resolver address a new DNSResolverService starts with, before any update.
+	defaultResolverAddr = netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), defaultResolverPort)
+)
+
+// netDial is the dial signature used by the resolver shim; NewDNSResolver supplies net.Dial for it.
+type netDial func(network string, address string) (net.Conn, error)
+
+// DNSResolverService will make DNS requests to the local DNS resolver via the DialUDP method.
+type DNSResolverService struct {
+	// address is the resolver address currently dialed; reads and writes go through addressM.
+	address  netip.AddrPort
+	addressM sync.RWMutex
+
+	// dialer performs the actual UDP dial towards address.
+	dialer   ingress.UDPOriginProxy
+	// resolver performs the refresh lookup and reports which resolver address was dialed for it.
+	resolver peekResolver
+	logger   *zerolog.Logger
+}
+
+// NewDNSResolver returns a DNSResolverService that starts out pointing at defaultResolverAddr, dials through
+// ingress.DefaultUDPDialer and logs to logger.
+func NewDNSResolver(logger *zerolog.Logger) *DNSResolverService {
+	return &DNSResolverService{
+		address:  defaultResolverAddr,
+		dialer:   ingress.DefaultUDPDialer,
+		resolver: &resolver{dialFunc: net.Dial},
+		logger:   logger,
+	}
+}
+
+// DialUDP dials the currently stored local DNS resolver address; the requested address is not used.
+func (s *DNSResolverService) DialUDP(_ netip.AddrPort) (net.Conn, error) {
+	// Copy the address under the read lock so the lock is not held for the duration of the dial.
+	s.addressM.RLock()
+	dest := s.address
+	s.addressM.RUnlock()
+	// The dialer ignores the provided address because the request will instead go to the local DNS resolver.
+	return s.dialer.DialUDP(dest)
+}
+
+// StartRefreshLoop is a routine that is expected to run in the background to update the DNS local resolver if
+// adjusted while the cloudflared process is running.
+func (s *DNSResolverService) StartRefreshLoop(ctx context.Context) {
+	// Call update first to load an address before handling traffic
+	if err := s.update(ctx); err != nil {
+		s.logger.Err(err).Msg("Failed to initialize DNS local resolver")
+	}
+
+	// Create the ticker once, outside the loop. Calling time.Tick inside the select would allocate a brand new
+	// ticker on every iteration instead of reusing a single one for the lifetime of the loop.
+	ticker := time.NewTicker(refreshFreq)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			// The loop only ends when the context is cancelled.
+			return
+		case <-ticker.C:
+			// A failed refresh is logged and the previously stored address is kept.
+			if err := s.update(ctx); err != nil {
+				s.logger.Err(err).Msg("Failed to refresh DNS local resolver")
+			}
+		}
+	}
+}
+
+// update performs one lookup of defaultLookupHost, bounded by refreshTimeout, and stores the address the
+// resolver reports having dialed. It returns the lookup error, or the address parse error, unchanged; in both
+// cases the stored address is left untouched.
+func (s *DNSResolverService) update(ctx context.Context) error {
+	ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
+	defer cancel()
+	// Make a standard DNS request to a well-known DNS record that will last a long time
+	_, err := s.resolver.lookupNetIP(ctx, defaultLookupHost)
+	if err != nil {
+		return err
+	}
+
+	// Validate the address before updating internal reference
+	_, address := s.resolver.addr()
+	peekAddrPort, err := netip.ParseAddrPort(address)
+	if err == nil {
+		s.setAddress(peekAddrPort)
+		return nil
+	}
+	// It's possible that the address didn't have an attached port, attempt to parse just the address and use
+	// the default port 53
+	peekAddr, err := netip.ParseAddr(address)
+	if err != nil {
+		return err
+	}
+	s.setAddress(netip.AddrPortFrom(peekAddr, defaultResolverPort))
+	return nil
+}
+
+// setAddress stores addr as the local DNS resolver address while holding the write lock. A debug message is
+// logged only when addr differs from the address stored before the call.
+func (s *DNSResolverService) setAddress(addr netip.AddrPort) {
+	s.addressM.Lock()
+	defer s.addressM.Unlock()
+	if s.address != addr {
+		s.logger.Debug().Msgf("Updating DNS local resolver: %s", addr)
+	}
+	s.address = addr
+}
+
+// peekResolver is the lookup dependency of DNSResolverService.
+type peekResolver interface {
+	// addr reports the network and address associated with the resolver; update reads only the address.
+	addr() (network string, address string)
+	// lookupNetIP resolves host; update only checks the returned error.
+	lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error)
+}
+
+// resolver is a shim that inspects the go runtime's DNS resolution process to capture the DNS resolver
+// address used to complete a DNS request.
+type resolver struct {
+	// mu guards network and address. peekDial is installed as the net.Resolver Dial hook, and a single lookup
+	// may dial from more than one goroutine, so the recorded values must not be written unsynchronized.
+	mu       sync.Mutex
+	network  string
+	address  string
+	dialFunc netDial
+}
+
+// addr returns the network and address recorded by the most recent peekDial call. Both are empty when no dial
+// has happened yet.
+func (r *resolver) addr() (network string, address string) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	return r.network, r.address
+}
+
+// lookupNetIP resolves host with Go's built-in resolver, routing every connection through peekDial so the
+// resolver address in use gets recorded.
+func (r *resolver) lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error) {
+	res := &net.Resolver{
+		PreferGo: true,
+		// Use the peekDial to inspect the results of the DNS resolver used during the LookupNetIP call.
+		Dial: r.peekDial,
+	}
+	return res.LookupNetIP(ctx, "ip", host)
+}
+
+// peekDial records the network and address the Go resolver is about to dial and then performs the dial via
+// dialFunc. The context is not forwarded because dialFunc takes no context parameter.
+func (r *resolver) peekDial(ctx context.Context, network, address string) (net.Conn, error) {
+	// Only hold the lock while recording; the dial itself must not serialize concurrent queries.
+	r.mu.Lock()
+	r.network = network
+	r.address = address
+	r.mu.Unlock()
+	return r.dialFunc(network, address)
+}
diff --git a/ingress/origins/dns_test.go b/ingress/origins/dns_test.go
new file mode 100644
index 00000000..db8ebf28
--- /dev/null
+++ b/ingress/origins/dns_test.go
@@ -0,0 +1,134 @@
+package origins
+
+import (
+	"context"
+	"errors"
+	"net"
+	"net/netip"
+	"testing"
+
+	"github.com/rs/zerolog"
+)
+
+func TestDNSResolver_DefaultResolver(t *testing.T) {
+	log := zerolog.Nop()
+	service := NewDNSResolver(&log)
+	mockResolver := &mockPeekResolver{
+		address: "127.0.0.2:53",
+	}
+	service.resolver = mockResolver
+	if service.address != defaultResolverAddr {
+		t.Errorf("resolver address should be the default: %s, was: %s", defaultResolverAddr, service.address)
+	}
+}
+
+func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
+	log := zerolog.Nop()
+	service := NewDNSResolver(&log)
+
+	mockResolver := &mockPeekResolver{}
+	service.resolver = mockResolver
+
+	expectedAddr := netip.MustParseAddrPort("127.0.0.2:53")
+	addresses := []string{
+		"127.0.0.2:53",
+		"127.0.0.2", // missing port should be added (even though this is unlikely to happen)
+	}
+
+	for _, addr := range addresses {
+		mockResolver.address = addr
+		// Update the resolver address
+		err := service.update(t.Context())
+		if err != nil {
+			t.Error(err)
+		}
+		// Validate expected
+		if service.address != expectedAddr {
+			t.Errorf("resolver address should be: %s, was: %s", expectedAddr, service.address)
+		}
+	}
+}
+
+func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
+	log := zerolog.Nop()
+	service := NewDNSResolver(&log)
+	mockResolver := &mockPeekResolver{}
+	service.resolver = mockResolver
+
+	invalidAddresses := []string{
+		"999.999.999.999",
+		"localhost",
+		"255.255.255",
+	}
+
+	for _, addr := range invalidAddresses {
+		mockResolver.address = addr
+		// Update the resolver address should not update for these invalid addresses
+		err := service.update(t.Context())
+		if err == nil {
+			t.Error("service update should throw an error")
+		}
+		// Validate expected
+		if service.address != defaultResolverAddr {
+			t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address)
+		}
+	}
+}
+
+func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) {
+	log := zerolog.Nop()
+	service := NewDNSResolver(&log)
+	resolverErr := errors.New("test resolver error")
+	mockResolver := &mockPeekResolver{err: resolverErr}
+	service.resolver = mockResolver
+
+	// Update the resolver address should not update when the resolver cannot complete the lookup
+	err := service.update(t.Context())
+	if !errors.Is(err, resolverErr) {
+		t.Error("service update should throw an error")
+	}
+	// Validate expected
+	if service.address != defaultResolverAddr {
+		t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address)
+	}
+}
+
+func TestDNSResolver_DialUsesResolvedAddress(t *testing.T) {
+	log := zerolog.Nop()
+	service := NewDNSResolver(&log)
+	mockResolver := &mockPeekResolver{}
+	service.resolver = mockResolver
+	mockDialer := &mockDialer{expected: defaultResolverAddr}
+	service.dialer = mockDialer
+
+	// Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53
+	_, err := service.DialUDP(netip.MustParseAddrPort("127.0.0.2:53"))
+	if err != nil {
+		t.Error(err)
+	}
+}
+
+type mockPeekResolver struct {
+	err     error
+	address string
+}
+
+func (r *mockPeekResolver) addr() (network, address string) {
+	return "udp", r.address
+}
+
+func (r *mockPeekResolver) lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error) {
+	// We can return an empty result as it doesn't matter as long as the lookup doesn't fail
+	return []netip.Addr{}, r.err
+}
+
+type mockDialer struct {
+	expected netip.AddrPort
+}
+
+func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) {
+	if d.expected != addr {
+		return nil, errors.New("unexpected address dialed")
+	}
+	return nil, nil
+}
diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go
index 965b8c0d..fa70d29a 100644
--- a/supervisor/supervisor.go
+++ b/supervisor/supervisor.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"net"
+	"net/netip"
 	"strings"
 	"time"
 
@@ -14,6 +15,7 @@ import (
 	"github.com/cloudflare/cloudflared/connection"
 	"github.com/cloudflare/cloudflared/edgediscovery"
 	"github.com/cloudflare/cloudflared/ingress"
+	"github.com/cloudflare/cloudflared/ingress/origins"
 	"github.com/cloudflare/cloudflared/orchestration"
 	v3 "github.com/cloudflare/cloudflared/quic/v3"
 	"github.com/cloudflare/cloudflared/retry"
@@ -78,8 +80,11 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
 	edgeBindAddr := config.EdgeBindAddr
 
 	datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
-	// No reserved ingress services for now, hence the nil
-	ingressUDPService := ingress.NewUDPOriginService(nil, config.Log)
+
+	// Setup the reserved virtual origins
+	reservedServices := map[netip.AddrPort]ingress.UDPOriginProxy{}
+	reservedServices[origins.VirtualDNSServiceAddr] = config.OriginDNSService
+	ingressUDPService := ingress.NewUDPOriginService(reservedServices, config.Log)
 	sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingressUDPService, orchestrator.GetFlowLimiter())
 
 	edgeTunnelServer := EdgeTunnelServer{
@@ -128,6 +133,9 @@ func (s *Supervisor) Run(
 		}()
 	}
 
+	// Setup DNS Resolver refresh
+	go s.config.OriginDNSService.StartRefreshLoop(ctx)
+
 	if err := s.initialize(ctx, connectedSignal); err != nil {
 		if err == errEarlyShutdown {
 			return nil
diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go
index 18c294c5..ede0f8ed 100644
--- a/supervisor/tunnel.go
+++ b/supervisor/tunnel.go
@@ -24,6 +24,7 @@ import (
 	"github.com/cloudflare/cloudflared/features"
 	"github.com/cloudflare/cloudflared/fips"
 	"github.com/cloudflare/cloudflared/ingress"
+	"github.com/cloudflare/cloudflared/ingress/origins"
 	"github.com/cloudflare/cloudflared/management"
 	"github.com/cloudflare/cloudflared/orchestration"
 	quicpogs "github.com/cloudflare/cloudflared/quic"
@@ -64,6 +65,7 @@ type TunnelConfig struct {
 	ProtocolSelector connection.ProtocolSelector
 	EdgeTLSConfigs   map[connection.Protocol]*tls.Config
 	ICMPRouterServer ingress.ICMPRouterServer
+	OriginDNSService *origins.DNSResolverService
 
 	RPCTimeout         time.Duration
 	WriteStreamTimeout time.Duration