mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 15:39:58 +00:00
TUN-2819: cloudflared should close its connections when a signal is sent
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -105,9 +106,9 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
|
||||
logger := s.config.Logger
|
||||
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
||||
return err
|
||||
}
|
||||
var tunnelsWaiting []int
|
||||
@@ -157,7 +158,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
|
||||
case <-backoffTimer:
|
||||
backoffTimer = nil
|
||||
for _, index := range tunnelsWaiting {
|
||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
|
||||
}
|
||||
tunnelsActive += len(tunnelsWaiting)
|
||||
tunnelsWaiting = nil
|
||||
@@ -191,7 +192,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
|
||||
}
|
||||
|
||||
// Returns nil if initialization succeeded, else the initialization error.
|
||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
|
||||
logger := s.logger
|
||||
|
||||
s.lastResolve = time.Now()
|
||||
@@ -201,7 +202,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||
s.config.HAConnections = availableAddrs
|
||||
}
|
||||
|
||||
go s.startFirstTunnel(ctx, connectedSignal)
|
||||
go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
<-s.tunnelErrors
|
||||
@@ -213,7 +214,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||
// At least one successful connection, so start the rest
|
||||
for i := 1; i < s.config.HAConnections; i++ {
|
||||
ch := signal.New(make(chan struct{}))
|
||||
go s.startTunnel(ctx, i, ch)
|
||||
go s.startTunnel(ctx, i, ch, reconnectCh)
|
||||
time.Sleep(registrationInterval)
|
||||
}
|
||||
return nil
|
||||
@@ -221,7 +222,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||
|
||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
|
||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
|
||||
var (
|
||||
addr *net.TCPAddr
|
||||
err error
|
||||
@@ -236,7 +237,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||
return
|
||||
}
|
||||
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||
// If the first tunnel disconnects, keep restarting it.
|
||||
edgeErrors := 0
|
||||
for s.unusedIPs() {
|
||||
@@ -259,13 +260,13 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||
return
|
||||
}
|
||||
}
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||
}
|
||||
}
|
||||
|
||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors.
|
||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
|
||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
|
||||
var (
|
||||
addr *net.TCPAddr
|
||||
err error
|
||||
@@ -278,7 +279,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool)
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||
}
|
||||
|
||||
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
||||
|
Reference in New Issue
Block a user