TUN-3118: Changed graceful shutdown to immediately unregister the tunnel from the edge and keep the connection open until the edge drops it or the grace period expires

This commit is contained in:
Igor Postelnik
2021-01-20 13:41:09 -06:00
parent db0562c7b8
commit d503aeaf77
10 changed files with 295 additions and 80 deletions

View File

@@ -2,6 +2,7 @@ package connection
import (
"context"
"io"
"net"
"net/http"
"time"
@@ -28,7 +29,12 @@ type h2muxConnection struct {
connIndexStr string
connIndex uint8
observer *Observer
observer *Observer
gracefulShutdownC chan struct{}
stoppedGracefully bool
// newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
}
type MuxerConfig struct {
@@ -57,13 +63,16 @@ func NewH2muxConnection(
edgeConn net.Conn,
connIndex uint8,
observer *Observer,
gracefulShutdownC chan struct{},
) (*h2muxConnection, error, bool) {
h := &h2muxConnection{
config: config,
muxerConfig: muxerConfig,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
observer: observer,
config: config,
muxerConfig: muxerConfig,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
observer: observer,
gracefulShutdownC: gracefulShutdownC,
newRPCClientFunc: newRegistrationRPCClient,
}
// Establish a muxed connection with the edge
@@ -77,21 +86,14 @@ func NewH2muxConnection(
return h, nil, false
}
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return h.serveMuxer(serveCtx)
})
errGroup.Go(func() error {
stream, err := h.newRPCStream(serveCtx, register)
if err != nil {
return err
}
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer.log)
defer rpcClient.Close()
if err = rpcClient.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
if err := h.registerNamedTunnel(serveCtx, namedTunnel, connOptions); err != nil {
return err
}
connectedFuse.Connected()
@@ -137,6 +139,10 @@ func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel
return errGroup.Wait()
}
// StoppedGracefully reports the connection's stoppedGracefully flag, which
// controlLoop sets to true after a signal arrives on gracefulShutdownC.
func (h *h2muxConnection) StoppedGracefully() bool {
	return h.stoppedGracefully
}
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
// All routines should stop when muxer finish serving. When muxer is shutdown
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
@@ -152,13 +158,21 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect
updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
for {
select {
case <-h.gracefulShutdownC:
if connectedFuse.IsConnected() {
h.unregister(isNamedTunnel)
}
h.stoppedGracefully = true
h.gracefulShutdownC = nil
case <-ctx.Done():
// UnregisterTunnel blocks until the RPC call returns
if connectedFuse.IsConnected() {
if !h.stoppedGracefully && connectedFuse.IsConnected() {
h.unregister(isNamedTunnel)
}
h.muxer.Shutdown()
return
case <-updateMetricsTickC:
h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
}

View File

@@ -11,10 +11,12 @@ import (
"testing"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/gobwas/ws/wsutil"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/h2mux"
)
var (
@@ -32,13 +34,20 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
go func() {
edgeMuxConfig := h2mux.MuxerConfig{
Log: testObserver.log,
Handler: h2mux.MuxedStreamFunc(func(stream *h2mux.MuxedStream) error {
// we only expect RPC traffic in client->edge direction, provide minimal support for mocking
require.True(t, stream.IsRPCStream())
return stream.WriteHeaders([]h2mux.Header{
{Name: ":status", Value: "200"},
})
}),
}
edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
require.NoError(t, err)
edgeMuxChan <- edgeMux
}()
var connIndex = uint8(0)
h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver)
h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil)
require.NoError(t, err)
return h2muxConn, <-edgeMuxChan
}
@@ -168,6 +177,55 @@ func TestServeStreamWS(t *testing.T) {
wg.Wait()
}
// TestGracefulShutdownH2Mux verifies that closing the graceful-shutdown
// channel makes controlLoop unregister through the (mocked) RPC client, and
// that afterwards the connection is marked as stopped gracefully and its
// gracefulShutdownC field has been reset to nil.
func TestGracefulShutdownH2Mux(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	h2muxConn, edgeMux := newH2MuxConnection(t)

	shutdownC := make(chan struct{})
	unregisteredC := make(chan struct{})
	h2muxConn.gracefulShutdownC = shutdownC
	// Swap in a mock RPC client so the test can observe the unregister call.
	h2muxConn.newRPCClientFunc = func(_ context.Context, _ io.ReadWriteCloser, _ *zerolog.Logger) NamedTunnelRPCClient {
		return &mockNamedTunnelRPCClient{
			registered:   nil,
			unregistered: unregisteredC,
		}
	}

	var wg sync.WaitGroup
	wg.Add(3)
	go func() {
		defer wg.Done()
		_ = edgeMux.Serve(ctx)
	}()
	go func() {
		defer wg.Done()
		_ = h2muxConn.serveMuxer(ctx)
	}()
	go func() {
		defer wg.Done()
		h2muxConn.controlLoop(ctx, &mockConnectedFuse{}, true)
	}()

	// NOTE(review): presumably this gives the goroutines above time to start
	// before shutdown is signalled — confirm whether the delay is required.
	time.Sleep(100 * time.Millisecond)
	close(shutdownC)

	select {
	case <-unregisteredC:
		// ok
	// time.After is a one-shot timer; time.Tick would start a ticker that
	// can never be stopped.
	case <-time.After(time.Second):
		assert.Fail(t, "timed out waiting for control loop to unregister")
	}

	cancel()
	wg.Wait()

	assert.True(t, h2muxConn.stoppedGracefully)
	assert.Nil(t, h2muxConn.gracefulShutdownC)
}
func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
for _, header := range stream.Headers {
if header.Name == name && header.Value == val {

View File

@@ -2,6 +2,7 @@ package connection
import (
"context"
"fmt"
"io"
"math"
"net"
@@ -22,6 +23,8 @@ const (
controlStreamUpgrade = "control-stream"
)
// errEdgeConnectionClosed is returned by http2Connection.Serve when serving
// ends without the connection having been stopped gracefully.
var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
type http2Connection struct {
conn net.Conn
server *http2.Server
@@ -33,8 +36,10 @@ type http2Connection struct {
connIndex uint8
wg *sync.WaitGroup
// newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
connectedFuse ConnectedFuse
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
connectedFuse ConnectedFuse
gracefulShutdownC chan struct{}
stoppedGracefully bool
}
func NewHTTP2Connection(
@@ -45,25 +50,27 @@ func NewHTTP2Connection(
observer *Observer,
connIndex uint8,
connectedFuse ConnectedFuse,
gracefulShutdownC chan struct{},
) *http2Connection {
return &http2Connection{
conn: conn,
server: &http2.Server{
MaxConcurrentStreams: math.MaxUint32,
},
config: config,
namedTunnel: namedTunnelConfig,
connOptions: connOptions,
observer: observer,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
wg: &sync.WaitGroup{},
newRPCClientFunc: newRegistrationRPCClient,
connectedFuse: connectedFuse,
config: config,
namedTunnel: namedTunnelConfig,
connOptions: connOptions,
observer: observer,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
wg: &sync.WaitGroup{},
newRPCClientFunc: newRegistrationRPCClient,
connectedFuse: connectedFuse,
gracefulShutdownC: gracefulShutdownC,
}
}
func (c *http2Connection) Serve(ctx context.Context) {
func (c *http2Connection) Serve(ctx context.Context) error {
go func() {
<-ctx.Done()
c.close()
@@ -72,6 +79,11 @@ func (c *http2Connection) Serve(ctx context.Context) {
Context: ctx,
Handler: c,
})
if !c.stoppedGracefully {
return errEdgeConnectionClosed
}
return nil
}
func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -106,6 +118,10 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
// StoppedGracefully reports the connection's stoppedGracefully flag, which
// serveControlStream sets to true when a signal arrives on gracefulShutdownC.
func (c *http2Connection) StoppedGracefully() bool {
	return c.stoppedGracefully
}
func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
defer rpcClient.Close()
@@ -115,8 +131,16 @@ func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *ht
}
c.connectedFuse.Connected()
<-ctx.Done()
// wait for connection termination or start of graceful shutdown
select {
case <-ctx.Done():
break
case <-c.gracefulShutdownC:
c.stoppedGracefully = true
}
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
return nil
}

View File

@@ -12,6 +12,8 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@@ -36,6 +38,7 @@ func newTestHTTP2Connection() (*http2Connection, net.Conn) {
testObserver,
connIndex,
mockConnectedFuse{},
nil,
), edgeConn
}
@@ -241,10 +244,64 @@ func TestServeControlStream(t *testing.T) {
<-rpcClientFactory.registered
cancel()
<-rpcClientFactory.unregistered
assert.False(t, http2Conn.stoppedGracefully)
wg.Wait()
}
// TestGracefulShutdownHTTP2 verifies that, once the control stream has
// registered, closing gracefulShutdownC triggers the (mocked) unregister RPC
// and marks the connection as stopped gracefully.
func TestGracefulShutdownHTTP2(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	rpcClientFactory := mockRPCClientFactory{
		registered:   make(chan struct{}),
		unregistered: make(chan struct{}),
	}
	http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
	http2Conn.gracefulShutdownC = make(chan struct{})

	ctx, cancel := context.WithCancel(context.Background())
	// Ensure the goroutines below are released even if the test aborts early
	// through t.Fatal / require; calling cancel more than once is harmless.
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// The returned error is irrelevant here: the test asserts on the
		// stoppedGracefully flag instead.
		_ = http2Conn.Serve(ctx)
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
	require.NoError(t, err)
	req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)

	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	wg.Add(1)
	go func() {
		defer wg.Done()
		_, _ = edgeHTTP2Conn.RoundTrip(req)
	}()

	select {
	case <-rpcClientFactory.registered:
		// ok
	// time.After is a one-shot timer; time.Tick would start a ticker that
	// can never be stopped.
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for registration")
	}

	// signal graceful shutdown
	close(http2Conn.gracefulShutdownC)

	select {
	case <-rpcClientFactory.unregistered:
		// ok
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for unregistered signal")
	}
	assert.True(t, http2Conn.stoppedGracefully)

	cancel()
	wg.Wait()
}
func benchmarkServeHTTP(b *testing.B, test testRequest) {
http2Conn, edgeConn := newTestHTTP2Connection()
@@ -281,6 +338,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
cancel()
wg.Wait()
}
func BenchmarkServeHTTPSimple(b *testing.B) {
test := testRequest{
name: "ok",

View File

@@ -272,17 +272,36 @@ func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelSe
return nil
}
// registerNamedTunnel opens a new RPC stream of the register kind, builds an
// RPC client on top of it through h.newRPCClientFunc (which the tests replace
// with a mock) and calls RegisterConnection for this connection index.
// It returns the stream-creation error or the registration error, if any.
// The RPC client is closed before the function returns.
func (h *h2muxConnection) registerNamedTunnel(
	ctx context.Context,
	namedTunnel *NamedTunnelConfig,
	connOptions *tunnelpogs.ConnectionOptions,
) error {
	rpcStream, streamErr := h.newRPCStream(ctx, register)
	if streamErr != nil {
		return streamErr
	}

	client := h.newRPCClientFunc(ctx, rpcStream, h.observer.log)
	defer client.Close()

	return client.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer)
}
func (h *h2muxConnection) unregister(isNamedTunnel bool) {
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
defer cancel()
stream, err := h.newRPCStream(unregisterCtx, register)
stream, err := h.newRPCStream(unregisterCtx, unregister)
if err != nil {
return
}
defer stream.Close()
if isNamedTunnel {
rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer.log)
rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log)
defer rpcClient.Close()
rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
@@ -293,4 +312,6 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) {
// gracePeriod is encoded in int64 using capnproto
_ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
}
h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection")
}