mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 16:49:57 +00:00
TUN-9470: Add OriginDialerService to include TCP
Adds an OriginDialerService that takes over the roles of both DialUDP and DialTCP towards the origin. This provides the possibility to leverage dialer "middleware" to inject virtual origins, such as the DNS resolver service. DNS Resolver service also gains access to the DialTCP operation to service TCP DNS requests. Minor refactoring includes changes to remove the needs previously provided by the warp-routing configuration. This configuration cannot be disabled by cloudflared so many of the references have been adjusted or removed. Closes TUN-9470
This commit is contained in:
@@ -19,7 +19,7 @@ import (
|
||||
type OriginConnection interface {
|
||||
// Stream should generally be implemented as a bidirectional io.Copy.
|
||||
Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger)
|
||||
Close()
|
||||
Close() error
|
||||
}
|
||||
|
||||
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
|
||||
@@ -48,16 +48,7 @@ func (tc *tcpConnection) Write(b []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
nBytes, err := tc.Conn.Write(b)
|
||||
if err != nil {
|
||||
tc.logger.Err(err).Msg("Error writing to the TCP connection")
|
||||
}
|
||||
|
||||
return nBytes, err
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Close() {
|
||||
tc.Conn.Close()
|
||||
return tc.Conn.Write(b)
|
||||
}
|
||||
|
||||
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
|
||||
@@ -75,8 +66,8 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri
|
||||
wsConn.Close()
|
||||
}
|
||||
|
||||
func (wc *tcpOverWSConnection) Close() {
|
||||
wc.conn.Close()
|
||||
func (wc *tcpOverWSConnection) Close() error {
|
||||
return wc.conn.Close()
|
||||
}
|
||||
|
||||
// socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS.
|
||||
@@ -95,5 +86,6 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
|
||||
wsConn.Close()
|
||||
}
|
||||
|
||||
func (sp *socksProxyOverWSConnection) Close() {
|
||||
func (sp *socksProxyOverWSConnection) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
146
ingress/origin_dialer.go
Normal file
146
ingress/origin_dialer.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// OriginTCPDialer provides a TCP dial operation to a requested address.
|
||||
type OriginTCPDialer interface {
|
||||
DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
|
||||
}
|
||||
|
||||
// OriginUDPDialer provides a UDP dial operation to a requested address.
|
||||
type OriginUDPDialer interface {
|
||||
DialUDP(addr netip.AddrPort) (net.Conn, error)
|
||||
}
|
||||
|
||||
// OriginDialer provides both TCP and UDP dial operations to an address.
|
||||
type OriginDialer interface {
|
||||
OriginTCPDialer
|
||||
OriginUDPDialer
|
||||
}
|
||||
|
||||
type OriginConfig struct {
|
||||
// The default Dialer used if no reserved services are found for an origin request.
|
||||
DefaultDialer OriginDialer
|
||||
// Timeout on write operations for TCP connections to the origin.
|
||||
TCPWriteTimeout time.Duration
|
||||
}
|
||||
|
||||
// OriginDialerService provides a proxy TCP and UDP dialer to origin services while allowing reserved
|
||||
// services to be provided. These reserved services are assigned to specific [netip.AddrPort]s
|
||||
// and provide their own [OriginDialer]'s to handle origin dialing per protocol.
|
||||
type OriginDialerService struct {
|
||||
// Reserved TCP services for reserved AddrPort values
|
||||
reservedTCPServices map[netip.AddrPort]OriginTCPDialer
|
||||
// Reserved UDP services for reserved AddrPort values
|
||||
reservedUDPServices map[netip.AddrPort]OriginUDPDialer
|
||||
// The default Dialer used if no reserved services are found for an origin request
|
||||
defaultDialer OriginDialer
|
||||
defaultDialerM sync.RWMutex
|
||||
// Write timeout for TCP connections
|
||||
writeTimeout time.Duration
|
||||
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewOriginDialer(config OriginConfig, logger *zerolog.Logger) *OriginDialerService {
|
||||
return &OriginDialerService{
|
||||
reservedTCPServices: map[netip.AddrPort]OriginTCPDialer{},
|
||||
reservedUDPServices: map[netip.AddrPort]OriginUDPDialer{},
|
||||
defaultDialer: config.DefaultDialer,
|
||||
writeTimeout: config.TCPWriteTimeout,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// AddReservedService adds a reserved virtual service to dial to.
|
||||
// Not locked and expected to be initialized before calling first dial and not afterwards.
|
||||
func (d *OriginDialerService) AddReservedService(service OriginDialer, addrs []netip.AddrPort) {
|
||||
for _, addr := range addrs {
|
||||
d.reservedTCPServices[addr] = service
|
||||
d.reservedUDPServices[addr] = service
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateDefaultDialer updates the default dialer.
|
||||
func (d *OriginDialerService) UpdateDefaultDialer(dialer *Dialer) {
|
||||
d.defaultDialerM.Lock()
|
||||
defer d.defaultDialerM.Unlock()
|
||||
d.defaultDialer = dialer
|
||||
}
|
||||
|
||||
// DialTCP will perform a dial TCP to the requested addr.
|
||||
func (d *OriginDialerService) DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) {
|
||||
conn, err := d.dialTCP(ctx, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Assign the write timeout for the TCP operations
|
||||
return &tcpConnection{
|
||||
Conn: conn,
|
||||
writeTimeout: d.writeTimeout,
|
||||
logger: d.logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *OriginDialerService) dialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) {
|
||||
// Check to see if any reserved services are available for this addr and call their dialer instead.
|
||||
if dialer, ok := d.reservedTCPServices[addr]; ok {
|
||||
return dialer.DialTCP(ctx, addr)
|
||||
}
|
||||
d.defaultDialerM.RLock()
|
||||
dialer := d.defaultDialer
|
||||
d.defaultDialerM.RUnlock()
|
||||
return dialer.DialTCP(ctx, addr)
|
||||
}
|
||||
|
||||
// DialUDP will perform a dial UDP to the requested addr.
|
||||
func (d *OriginDialerService) DialUDP(addr netip.AddrPort) (net.Conn, error) {
|
||||
// Check to see if any reserved services are available for this addr and call their dialer instead.
|
||||
if dialer, ok := d.reservedUDPServices[addr]; ok {
|
||||
return dialer.DialUDP(addr)
|
||||
}
|
||||
d.defaultDialerM.RLock()
|
||||
dialer := d.defaultDialer
|
||||
d.defaultDialerM.RUnlock()
|
||||
return dialer.DialUDP(addr)
|
||||
}
|
||||
|
||||
type Dialer struct {
|
||||
Dialer net.Dialer
|
||||
}
|
||||
|
||||
func NewDialer(config WarpRoutingConfig) *Dialer {
|
||||
return &Dialer{
|
||||
Dialer: net.Dialer{
|
||||
Timeout: config.ConnectTimeout.Duration,
|
||||
KeepAlive: config.TCPKeepAlive.Duration,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dialer) DialTCP(ctx context.Context, dest netip.AddrPort) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, "tcp", dest.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to dial tcp to origin %s: %w", dest, err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (d *Dialer) DialUDP(dest netip.AddrPort) (net.Conn, error) {
|
||||
conn, err := d.Dialer.Dial("udp", dest.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
@@ -1,66 +0,0 @@
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// UDPOriginService provides a proxy UDP dialer to origin services while allowing reserved
|
||||
// services to be provided. These reserved services are assigned to specific [netip.AddrPort]s
|
||||
// and provide their own [UDPOriginProxy]s to handle UDP origin dialing.
|
||||
type UDPOriginService struct {
|
||||
// Reserved services for reserved AddrPort values
|
||||
reservedServices map[netip.AddrPort]UDPOriginProxy
|
||||
// The default UDP Dialer used if no reserved services are found for an origin request.
|
||||
defaultDialer UDPOriginProxy
|
||||
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
// UDPOriginProxy provides a UDP dial operation to a requested addr.
|
||||
type UDPOriginProxy interface {
|
||||
DialUDP(addr netip.AddrPort) (net.Conn, error)
|
||||
}
|
||||
|
||||
func NewUDPOriginService(reserved map[netip.AddrPort]UDPOriginProxy, logger *zerolog.Logger) *UDPOriginService {
|
||||
return &UDPOriginService{
|
||||
reservedServices: reserved,
|
||||
defaultDialer: DefaultUDPDialer,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetUDPDialer updates the default UDP Dialer used.
|
||||
// Typically used in unit testing.
|
||||
func (s *UDPOriginService) SetDefaultDialer(dialer UDPOriginProxy) {
|
||||
s.defaultDialer = dialer
|
||||
}
|
||||
|
||||
// DialUDP will perform a dial UDP to the requested addr.
|
||||
func (s *UDPOriginService) DialUDP(addr netip.AddrPort) (net.Conn, error) {
|
||||
// Check to see if any reserved services are available for this addr and call their dialer instead.
|
||||
if dialer, ok := s.reservedServices[addr]; ok {
|
||||
return dialer.DialUDP(addr)
|
||||
}
|
||||
return s.defaultDialer.DialUDP(addr)
|
||||
}
|
||||
|
||||
type defaultUDPDialer struct{}
|
||||
|
||||
var DefaultUDPDialer UDPOriginProxy = &defaultUDPDialer{}
|
||||
|
||||
func (d *defaultUDPDialer) DialUDP(dest netip.AddrPort) (net.Conn, error) {
|
||||
addr := net.UDPAddrFromAddrPort(dest)
|
||||
|
||||
// We use nil as local addr to force runtime to find the best suitable local address IP given the destination
|
||||
// address as context.
|
||||
udpConn, err := net.DialUDP("udp", nil, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err)
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
@@ -45,20 +45,28 @@ type DNSResolverService struct {
|
||||
address netip.AddrPort
|
||||
addressM sync.RWMutex
|
||||
|
||||
dialer ingress.UDPOriginProxy
|
||||
dialer ingress.OriginDialer
|
||||
resolver peekResolver
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewDNSResolver(logger *zerolog.Logger) *DNSResolverService {
|
||||
func NewDNSResolverService(dialer ingress.OriginDialer, logger *zerolog.Logger) *DNSResolverService {
|
||||
return &DNSResolverService{
|
||||
address: defaultResolverAddr,
|
||||
dialer: ingress.DefaultUDPDialer,
|
||||
dialer: dialer,
|
||||
resolver: &resolver{dialFunc: net.Dial},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DNSResolverService) DialTCP(ctx context.Context, _ netip.AddrPort) (net.Conn, error) {
|
||||
s.addressM.RLock()
|
||||
dest := s.address
|
||||
s.addressM.RUnlock()
|
||||
// The dialer ignores the provided address because the request will instead go to the local DNS resolver.
|
||||
return s.dialer.DialTCP(ctx, dest)
|
||||
}
|
||||
|
||||
func (s *DNSResolverService) DialUDP(_ netip.AddrPort) (net.Conn, error) {
|
||||
s.addressM.RLock()
|
||||
dest := s.address
|
||||
@@ -155,3 +163,18 @@ func (r *resolver) peekDial(ctx context.Context, network, address string) (net.C
|
||||
r.address = address
|
||||
return r.dialFunc(network, address)
|
||||
}
|
||||
|
||||
// NewDNSDialer creates a custom dialer for the DNS resolver service to utilize.
|
||||
func NewDNSDialer() *ingress.Dialer {
|
||||
return &ingress.Dialer{
|
||||
Dialer: net.Dialer{
|
||||
// We want short timeouts for the DNS requests
|
||||
Timeout: 5 * time.Second,
|
||||
// We do not want keep alive since the edge will not reuse TCP connections per request
|
||||
KeepAlive: -1,
|
||||
KeepAliveConfig: net.KeepAliveConfig{
|
||||
Enable: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
func TestDNSResolver_DefaultResolver(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
service := NewDNSResolver(&log)
|
||||
service := NewDNSResolverService(NewDNSDialer(), &log)
|
||||
mockResolver := &mockPeekResolver{
|
||||
address: "127.0.0.2:53",
|
||||
}
|
||||
@@ -24,7 +24,7 @@ func TestDNSResolver_DefaultResolver(t *testing.T) {
|
||||
|
||||
func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
service := NewDNSResolver(&log)
|
||||
service := NewDNSResolverService(NewDNSDialer(), &log)
|
||||
|
||||
mockResolver := &mockPeekResolver{}
|
||||
service.resolver = mockResolver
|
||||
@@ -51,7 +51,7 @@ func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
|
||||
|
||||
func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
service := NewDNSResolver(&log)
|
||||
service := NewDNSResolverService(NewDNSDialer(), &log)
|
||||
mockResolver := &mockPeekResolver{}
|
||||
service.resolver = mockResolver
|
||||
|
||||
@@ -77,7 +77,7 @@ func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
|
||||
|
||||
func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
service := NewDNSResolver(&log)
|
||||
service := NewDNSResolverService(NewDNSDialer(), &log)
|
||||
resolverErr := errors.New("test resolver error")
|
||||
mockResolver := &mockPeekResolver{err: resolverErr}
|
||||
service.resolver = mockResolver
|
||||
@@ -93,13 +93,12 @@ func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSResolver_DialUsesResolvedAddress(t *testing.T) {
|
||||
func TestDNSResolver_DialUDPUsesResolvedAddress(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
service := NewDNSResolver(&log)
|
||||
mockDialer := &mockDialer{expected: defaultResolverAddr}
|
||||
service := NewDNSResolverService(mockDialer, &log)
|
||||
mockResolver := &mockPeekResolver{}
|
||||
service.resolver = mockResolver
|
||||
mockDialer := &mockDialer{expected: defaultResolverAddr}
|
||||
service.dialer = mockDialer
|
||||
|
||||
// Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53
|
||||
_, err := service.DialUDP(netip.MustParseAddrPort("127.0.0.2:53"))
|
||||
@@ -108,6 +107,20 @@ func TestDNSResolver_DialUsesResolvedAddress(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSResolver_DialTCPUsesResolvedAddress(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
mockDialer := &mockDialer{expected: defaultResolverAddr}
|
||||
service := NewDNSResolverService(mockDialer, &log)
|
||||
mockResolver := &mockPeekResolver{}
|
||||
service.resolver = mockResolver
|
||||
|
||||
// Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53
|
||||
_, err := service.DialTCP(t.Context(), netip.MustParseAddrPort("127.0.0.2:53"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
type mockPeekResolver struct {
|
||||
err error
|
||||
address string
|
||||
@@ -126,6 +139,13 @@ type mockDialer struct {
|
||||
expected netip.AddrPort
|
||||
}
|
||||
|
||||
func (d *mockDialer) DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) {
|
||||
if d.expected != addr {
|
||||
return nil, errors.New("unexpected address dialed")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) {
|
||||
if d.expected != addr {
|
||||
return nil, errors.New("unexpected address dialed")
|
||||
|
Reference in New Issue
Block a user