mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 08:09:58 +00:00
TUN-2637: Manage edge IPs in a region-aware manner
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -41,11 +40,9 @@ var (
|
||||
)
|
||||
|
||||
type Supervisor struct {
|
||||
cloudflaredUUID uuid.UUID
|
||||
config *TunnelConfig
|
||||
edgeIPs []*net.TCPAddr
|
||||
// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
|
||||
nextUnusedEdgeIP int
|
||||
cloudflaredUUID uuid.UUID
|
||||
config *TunnelConfig
|
||||
edgeIPs connection.EdgeServiceDiscoverer
|
||||
lastResolve time.Time
|
||||
resolverC chan resolveResult
|
||||
tunnelErrors chan tunnelError
|
||||
@@ -65,25 +62,30 @@ type Supervisor struct {
|
||||
}
|
||||
|
||||
type resolveResult struct {
|
||||
edgeIPs []*net.TCPAddr
|
||||
err error
|
||||
err error
|
||||
}
|
||||
|
||||
type tunnelError struct {
|
||||
index int
|
||||
addr *net.TCPAddr
|
||||
err error
|
||||
}
|
||||
|
||||
func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor {
|
||||
func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
|
||||
edgeIPs, err := connection.NewEdgeAddrResolver(config.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Supervisor{
|
||||
cloudflaredUUID: u,
|
||||
config: config,
|
||||
edgeIPs: edgeIPs,
|
||||
tunnelErrors: make(chan tunnelError),
|
||||
tunnelsConnecting: map[int]chan struct{}{},
|
||||
logger: config.Logger.WithField("subsystem", "supervisor"),
|
||||
jwtLock: &sync.RWMutex{},
|
||||
eventDigestLock: &sync.RWMutex{},
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||
@@ -134,8 +136,8 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
|
||||
// If the error is a dial error, the problem is likely to be network related
|
||||
// try another addr before refreshing since we are likely to get back the
|
||||
// same IPs in the same order. Same problem with duplicate connection error.
|
||||
if s.unusedIPs() {
|
||||
s.replaceEdgeIP(tunnelError.index)
|
||||
if s.unusedIPs() && tunnelError.addr != nil {
|
||||
s.edgeIPs.MarkAddrBad(tunnelError.addr)
|
||||
} else {
|
||||
s.refreshEdgeIPs()
|
||||
}
|
||||
@@ -170,7 +172,6 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
|
||||
s.resolverC = nil
|
||||
if result.err == nil {
|
||||
logger.Debug("Service discovery refresh complete")
|
||||
s.edgeIPs = result.edgeIPs
|
||||
} else {
|
||||
logger.WithError(result.err).Error("Service discovery error")
|
||||
}
|
||||
@@ -182,19 +183,18 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
|
||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||
logger := s.logger
|
||||
|
||||
edgeIPs, err := s.resolveEdgeIPs()
|
||||
err := s.edgeIPs.Refresh()
|
||||
if err != nil {
|
||||
logger.Infof("ResolveEdgeIPs err")
|
||||
return err
|
||||
}
|
||||
s.edgeIPs = edgeIPs
|
||||
if s.config.HAConnections > len(edgeIPs) {
|
||||
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
|
||||
s.config.HAConnections = len(edgeIPs)
|
||||
}
|
||||
s.lastResolve = time.Now()
|
||||
// check entitlement and version too old error before attempting to register more tunnels
|
||||
s.nextUnusedEdgeIP = s.config.HAConnections
|
||||
availableAddrs := int(s.edgeIPs.AvailableAddrs())
|
||||
if s.config.HAConnections > availableAddrs {
|
||||
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
|
||||
s.config.HAConnections = availableAddrs
|
||||
}
|
||||
|
||||
go s.startFirstTunnel(ctx, connectedSignal)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -216,16 +216,24 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
|
||||
err := ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
|
||||
var (
|
||||
addr *net.TCPAddr
|
||||
err error
|
||||
)
|
||||
defer func() {
|
||||
s.tunnelErrors <- tunnelError{index: 0, err: err}
|
||||
s.tunnelErrors <- tunnelError{index: 0, addr: addr, err: err}
|
||||
}()
|
||||
|
||||
addr, err = s.edgeIPs.Addr()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, 0, connectedSignal, s.cloudflaredUUID)
|
||||
|
||||
for s.unusedIPs() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
default:
|
||||
}
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
@@ -233,19 +241,34 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||
// try the next address if it was a dialError(network problem) or
|
||||
// dupConnRegisterTunnelError
|
||||
case connection.DialError, dupConnRegisterTunnelError:
|
||||
s.replaceEdgeIP(0)
|
||||
s.edgeIPs.MarkAddrBad(addr)
|
||||
default:
|
||||
return
|
||||
}
|
||||
err = ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
|
||||
addr, err = s.edgeIPs.Addr()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, 0, connectedSignal, s.cloudflaredUUID)
|
||||
}
|
||||
}
|
||||
|
||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors.
|
||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
|
||||
err := ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
|
||||
s.tunnelErrors <- tunnelError{index: index, err: err}
|
||||
var (
|
||||
addr *net.TCPAddr
|
||||
err error
|
||||
)
|
||||
defer func() {
|
||||
s.tunnelErrors <- tunnelError{index: index, addr: addr, err: err}
|
||||
}()
|
||||
|
||||
addr, err = s.edgeIPs.Addr()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID)
|
||||
}
|
||||
|
||||
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
||||
@@ -267,17 +290,8 @@ func (s *Supervisor) waitForNextTunnel(index int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
|
||||
return s.edgeIPs[index%len(s.edgeIPs)]
|
||||
}
|
||||
|
||||
func (s *Supervisor) resolveEdgeIPs() ([]*net.TCPAddr, error) {
|
||||
// If --edge is specfied, resolve edge server addresses
|
||||
if len(s.config.EdgeAddrs) > 0 {
|
||||
return connection.ResolveAddrs(s.config.EdgeAddrs)
|
||||
}
|
||||
// Otherwise lookup edge server addresses through service discovery
|
||||
return connection.EdgeDiscovery(s.logger)
|
||||
func (s *Supervisor) unusedIPs() bool {
|
||||
return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
|
||||
}
|
||||
|
||||
func (s *Supervisor) refreshEdgeIPs() {
|
||||
@@ -289,20 +303,11 @@ func (s *Supervisor) refreshEdgeIPs() {
|
||||
}
|
||||
s.resolverC = make(chan resolveResult)
|
||||
go func() {
|
||||
edgeIPs, err := s.resolveEdgeIPs()
|
||||
s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
|
||||
err := s.edgeIPs.Refresh()
|
||||
s.resolverC <- resolveResult{err: err}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Supervisor) unusedIPs() bool {
|
||||
return s.nextUnusedEdgeIP < len(s.edgeIPs)
|
||||
}
|
||||
|
||||
func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
|
||||
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
|
||||
s.nextUnusedEdgeIP++
|
||||
}
|
||||
|
||||
func (s *Supervisor) ReconnectToken() ([]byte, error) {
|
||||
s.jwtLock.RLock()
|
||||
defer s.jwtLock.RUnlock()
|
||||
@@ -366,7 +371,11 @@ func (s *Supervisor) refreshAuth(
|
||||
}
|
||||
|
||||
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
|
||||
arbitraryEdgeIP := s.getEdgeIP(rand.Int())
|
||||
arbitraryEdgeIP, err := s.edgeIPs.AnyAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -24,7 +24,10 @@ func TestRefreshAuthBackoff(t *testing.T) {
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
s, err := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
backoff := &BackoffHandler{MaxRetries: 3}
|
||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return nil, fmt.Errorf("authentication failure")
|
||||
@@ -66,7 +69,10 @@ func TestRefreshAuthSuccess(t *testing.T) {
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
s, err := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
backoff := &BackoffHandler{MaxRetries: 3}
|
||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
|
||||
@@ -92,7 +98,10 @@ func TestRefreshAuthUnknown(t *testing.T) {
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
s, err := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
backoff := &BackoffHandler{MaxRetries: 3}
|
||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
||||
@@ -112,7 +121,10 @@ func TestRefreshAuthFail(t *testing.T) {
|
||||
logger := logrus.New()
|
||||
logger.Level = logrus.ErrorLevel
|
||||
|
||||
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
s, err := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
backoff := &BackoffHandler{MaxRetries: 3}
|
||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
|
||||
|
@@ -156,7 +156,11 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||
}
|
||||
|
||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
|
||||
return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal)
|
||||
s, err := NewSupervisor(config, cloudflaredID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Run(ctx, connectedSignal)
|
||||
}
|
||||
|
||||
func ServeTunnelLoop(ctx context.Context,
|
||||
|
Reference in New Issue
Block a user