TUN-2637: Manage edge IPs in a region-aware manner

This commit is contained in:
Nick Vollmar
2019-12-13 17:05:21 -06:00
parent 87102a2646
commit 7e31b77646
10 changed files with 1011 additions and 206 deletions

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"sync"
"time"
@@ -41,11 +40,9 @@ var (
)
type Supervisor struct {
cloudflaredUUID uuid.UUID
config *TunnelConfig
edgeIPs []*net.TCPAddr
// nextUnusedEdgeIP is the index of the next addr in edgeIPs to try
nextUnusedEdgeIP int
cloudflaredUUID uuid.UUID
config *TunnelConfig
edgeIPs connection.EdgeServiceDiscoverer
lastResolve time.Time
resolverC chan resolveResult
tunnelErrors chan tunnelError
@@ -65,25 +62,30 @@ type Supervisor struct {
}
type resolveResult struct {
edgeIPs []*net.TCPAddr
err error
err error
}
type tunnelError struct {
index int
addr *net.TCPAddr
err error
}
func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor {
func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
edgeIPs, err := connection.NewEdgeAddrResolver(config.Logger)
if err != nil {
return nil, err
}
return &Supervisor{
cloudflaredUUID: u,
config: config,
edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
logger: config.Logger.WithField("subsystem", "supervisor"),
jwtLock: &sync.RWMutex{},
eventDigestLock: &sync.RWMutex{},
}
}, nil
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
@@ -134,8 +136,8 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
// If the error is a dial error, the problem is likely to be network related.
// Try another addr before refreshing, since we are likely to get back the
// same IPs in the same order. The same applies to a duplicate connection error.
if s.unusedIPs() {
s.replaceEdgeIP(tunnelError.index)
if s.unusedIPs() && tunnelError.addr != nil {
s.edgeIPs.MarkAddrBad(tunnelError.addr)
} else {
s.refreshEdgeIPs()
}
@@ -170,7 +172,6 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
s.resolverC = nil
if result.err == nil {
logger.Debug("Service discovery refresh complete")
s.edgeIPs = result.edgeIPs
} else {
logger.WithError(result.err).Error("Service discovery error")
}
@@ -182,19 +183,18 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
logger := s.logger
edgeIPs, err := s.resolveEdgeIPs()
err := s.edgeIPs.Refresh()
if err != nil {
logger.Infof("ResolveEdgeIPs err")
return err
}
s.edgeIPs = edgeIPs
if s.config.HAConnections > len(edgeIPs) {
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
s.config.HAConnections = len(edgeIPs)
}
s.lastResolve = time.Now()
// check entitlement and version too old error before attempting to register more tunnels
s.nextUnusedEdgeIP = s.config.HAConnections
availableAddrs := int(s.edgeIPs.AvailableAddrs())
if s.config.HAConnections > availableAddrs {
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.config.HAConnections = availableAddrs
}
go s.startFirstTunnel(ctx, connectedSignal)
select {
case <-ctx.Done():
@@ -216,16 +216,24 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startFirstTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
var (
addr *net.TCPAddr
err error
)
defer func() {
s.tunnelErrors <- tunnelError{index: 0, err: err}
s.tunnelErrors <- tunnelError{index: 0, addr: addr, err: err}
}()
addr, err = s.edgeIPs.Addr()
if err != nil {
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, 0, connectedSignal, s.cloudflaredUUID)
for s.unusedIPs() {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return
default:
}
switch err.(type) {
case nil:
@@ -233,19 +241,34 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
// try the next address if it was a dialError(network problem) or
// dupConnRegisterTunnelError
case connection.DialError, dupConnRegisterTunnelError:
s.replaceEdgeIP(0)
s.edgeIPs.MarkAddrBad(addr)
default:
return
}
err = ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
addr, err = s.edgeIPs.Addr()
if err != nil {
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, 0, connectedSignal, s.cloudflaredUUID)
}
}
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
s.tunnelErrors <- tunnelError{index: index, err: err}
var (
addr *net.TCPAddr
err error
)
defer func() {
s.tunnelErrors <- tunnelError{index: index, addr: addr, err: err}
}()
addr, err = s.edgeIPs.Addr()
if err != nil {
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID)
}
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
@@ -267,17 +290,8 @@ func (s *Supervisor) waitForNextTunnel(index int) bool {
return false
}
func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
return s.edgeIPs[index%len(s.edgeIPs)]
}
func (s *Supervisor) resolveEdgeIPs() ([]*net.TCPAddr, error) {
// If --edge is specified, resolve edge server addresses
if len(s.config.EdgeAddrs) > 0 {
return connection.ResolveAddrs(s.config.EdgeAddrs)
}
// Otherwise lookup edge server addresses through service discovery
return connection.EdgeDiscovery(s.logger)
func (s *Supervisor) unusedIPs() bool {
return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
}
func (s *Supervisor) refreshEdgeIPs() {
@@ -289,20 +303,11 @@ func (s *Supervisor) refreshEdgeIPs() {
}
s.resolverC = make(chan resolveResult)
go func() {
edgeIPs, err := s.resolveEdgeIPs()
s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
err := s.edgeIPs.Refresh()
s.resolverC <- resolveResult{err: err}
}()
}
func (s *Supervisor) unusedIPs() bool {
return s.nextUnusedEdgeIP < len(s.edgeIPs)
}
func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
s.nextUnusedEdgeIP++
}
func (s *Supervisor) ReconnectToken() ([]byte, error) {
s.jwtLock.RLock()
defer s.jwtLock.RUnlock()
@@ -366,7 +371,11 @@ func (s *Supervisor) refreshAuth(
}
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
arbitraryEdgeIP := s.getEdgeIP(rand.Int())
arbitraryEdgeIP, err := s.edgeIPs.AnyAddr()
if err != nil {
return nil, err
}
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
if err != nil {
return nil, err

View File

@@ -24,7 +24,10 @@ func TestRefreshAuthBackoff(t *testing.T) {
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
s, err := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
if !assert.NoError(t, err) {
t.FailNow()
}
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return nil, fmt.Errorf("authentication failure")
@@ -66,7 +69,10 @@ func TestRefreshAuthSuccess(t *testing.T) {
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
s, err := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
if !assert.NoError(t, err) {
t.FailNow()
}
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
@@ -92,7 +98,10 @@ func TestRefreshAuthUnknown(t *testing.T) {
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
s, err := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
if !assert.NoError(t, err) {
t.FailNow()
}
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
@@ -112,7 +121,10 @@ func TestRefreshAuthFail(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
s, err := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
if !assert.NoError(t, err) {
t.FailNow()
}
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil

View File

@@ -156,7 +156,11 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
}
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal)
s, err := NewSupervisor(config, cloudflaredID)
if err != nil {
return err
}
return s.Run(ctx, connectedSignal)
}
func ServeTunnelLoop(ctx context.Context,