TUN-2819: cloudflared should close its connections when a signal is sent

This commit is contained in:
Adam Chalmers
2020-03-19 10:38:28 -05:00
parent 96f11de7ab
commit 6dcf3a4cbc
3 changed files with 30 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"os"
"sync"
"time"
@@ -105,9 +106,9 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
}, nil
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal); err != nil {
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err
}
var tunnelsWaiting []int
@@ -157,7 +158,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
case <-backoffTimer:
backoffTimer = nil
for _, index := range tunnelsWaiting {
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
}
tunnelsActive += len(tunnelsWaiting)
tunnelsWaiting = nil
@@ -191,7 +192,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
}
// Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
logger := s.logger
s.lastResolve = time.Now()
@@ -201,7 +202,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
s.config.HAConnections = availableAddrs
}
go s.startFirstTunnel(ctx, connectedSignal)
go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
select {
case <-ctx.Done():
<-s.tunnelErrors
@@ -213,7 +214,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ {
ch := signal.New(make(chan struct{}))
go s.startTunnel(ctx, i, ch)
go s.startTunnel(ctx, i, ch, reconnectCh)
time.Sleep(registrationInterval)
}
return nil
@@ -221,7 +222,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
var (
addr *net.TCPAddr
err error
@@ -236,7 +237,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
// If the first tunnel disconnects, keep restarting it.
edgeErrors := 0
for s.unusedIPs() {
@@ -259,13 +260,13 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return
}
}
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
}
}
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
var (
addr *net.TCPAddr
err error
@@ -278,7 +279,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
if err != nil {
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool)
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
}
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {