cloudflared/connection/h2mux.go
Igor Postelnik 8ca0d86c85 TUN-3863: Consolidate header handling logic in the connection package; move headers definitions from h2mux to packages that manage them; cleanup header conversions
All header transformation code from h2mux has been consolidated in the connection package since it's used by both h2mux and http2 logic.
Exported the headers used for proxying between the edge and cloudflared so they can be shared by the tunnel service on the edge.
Moved access-related headers to corresponding packages that have the code that sets/uses these headers.
Removed tunnel hostname tracking from h2mux since it wasn't used by anything. We will continue to set the tunnel hostname header from the edge for backward compatibility, but it's no longer used by cloudflared.
Moved bastion-related logic into the carrier package and untangled dependencies between the carrier, origin, and websocket packages.
2021-03-29 21:57:56 +00:00

261 lines
7.3 KiB
Go

package connection
import (
"context"
"io"
"net"
"net/http"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
)
// muxerTimeout is passed to h2mux as MuxerConfig.Timeout by H2MuxerConfig.
const muxerTimeout = 5 * time.Second

// openStreamTimeout bounds how long newRPCStream waits for the muxer to open
// an RPC stream.
const openStreamTimeout = 30 * time.Second
// h2muxConnection is one connection to the edge multiplexed with h2mux. The
// value is also handed to the muxer as its stream handler (see
// NewH2muxConnection), and edge-initiated streams are served by ServeStream.
type h2muxConnection struct {
config *Config
muxerConfig *MuxerConfig
// muxer is assigned by NewH2muxConnection after a successful handshake.
muxer *h2mux.Muxer
// connIndexStr is connIndex rendered as a string. It is used for metrics
// labels, and prometheus requires labels to be strings.
connIndexStr string
connIndex uint8
observer *Observer
// gracefulShutdownC signals a graceful shutdown; controlLoop sets it to nil
// after handling the signal so it is not acted on twice.
gracefulShutdownC <-chan struct{}
// stoppedGracefully is set by controlLoop when it starts a graceful muxer
// shutdown, so the Serve* methods can report errMuxerStopped as success.
stoppedGracefully bool
// newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
}
// MuxerConfig holds the per-connection h2mux settings. H2MuxerConfig turns it
// into an h2mux.MuxerConfig.
type MuxerConfig struct {
// HeartbeatInterval and MaxHeartbeats are passed through unchanged to
// h2mux.MuxerConfig.
HeartbeatInterval time.Duration
MaxHeartbeats uint64
// CompressionSetting is passed to h2mux as CompressionQuality.
CompressionSetting h2mux.CompressionSetting
// MetricsUpdateFreq is how often controlLoop reports muxer metrics.
MetricsUpdateFreq time.Duration
}
// H2MuxerConfig builds the client-side h2mux configuration from mc, using h as
// the handler for streams opened by the peer and log as the muxer's logger.
// The handshake timeout is always muxerTimeout.
func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Logger) *h2mux.MuxerConfig {
	cfg := new(h2mux.MuxerConfig)
	cfg.Timeout = muxerTimeout
	cfg.Handler = h
	cfg.IsClient = true
	cfg.HeartbeatInterval = mc.HeartbeatInterval
	cfg.MaxHeartbeats = mc.MaxHeartbeats
	cfg.Log = log
	cfg.CompressionQuality = mc.CompressionSetting
	return cfg
}
// NewH2muxConnection performs the client-side h2mux handshake with the edge
// over edgeConn and returns the resulting connection.
//
// The new connection is passed to the muxer as its stream handler. On success
// it returns (conn, nil, false). If the handshake fails, it returns
// (nil, err, recoverable), where recoverable is decided by
// isHandshakeErrRecoverable.
//
// NOTE(review): the error is not the last result, which is unidiomatic Go. The
// order is kept for compatibility with existing callers.
func NewH2muxConnection(
	config *Config,
	muxerConfig *MuxerConfig,
	edgeConn net.Conn,
	connIndex uint8,
	observer *Observer,
	gracefulShutdownC <-chan struct{},
) (*h2muxConnection, error, bool) {
	h := &h2muxConnection{
		config:            config,
		muxerConfig:       muxerConfig,
		connIndexStr:      uint8ToString(connIndex),
		connIndex:         connIndex,
		observer:          observer,
		gracefulShutdownC: gracefulShutdownC,
		newRPCClientFunc:  newRegistrationRPCClient,
	}

	// Establish a muxed connection with the edge (client side of the
	// handshake). The same conn is used for both reading and writing.
	muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer.logTransport), h2mux.ActiveStreams)
	if err != nil {
		recoverable := isHandshakeErrRecoverable(err, connIndex, observer)
		return nil, err, recoverable
	}
	h.muxer = muxer
	return h, nil, false
}
// ServeNamedTunnel concurrently serves the muxer, registers the named tunnel
// and runs the control loop, then waits for all three to finish.
//
// connectedFuse.Connected() is called once registration succeeds. The first
// error from the group is returned. If that error is errMuxerStopped and the
// muxer was stopped by a graceful shutdown, it returns nil. Otherwise it logs
// the unexpected shutdown and returns the error.
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
	group, groupCtx := errgroup.WithContext(ctx)

	group.Go(func() error { return h.serveMuxer(groupCtx) })
	group.Go(func() error {
		regErr := h.registerNamedTunnel(groupCtx, namedTunnel, connOptions)
		if regErr == nil {
			connectedFuse.Connected()
		}
		return regErr
	})
	group.Go(func() error {
		h.controlLoop(groupCtx, connectedFuse, true)
		return nil
	})

	err := group.Wait()
	if err != errMuxerStopped {
		return err
	}
	if h.stoppedGracefully {
		return nil
	}
	h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
	return err
}
// ServeClassicTunnel concurrently serves the muxer, registers the classic
// tunnel and runs the control loop, then waits for all three to finish.
//
// When classicTunnel.UseReconnectToken is set and the fuse is already
// connected, it first tries reconnectTunnel. If that fails, it logs the error
// and falls back to registerTunnel. connectedFuse.Connected() is called once
// either succeeds. The first error from the group is returned. If that error
// is errMuxerStopped and the muxer was stopped by a graceful shutdown, it
// returns nil. Otherwise it logs the unexpected shutdown and returns the error.
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
	errGroup, serveCtx := errgroup.WithContext(ctx)

	errGroup.Go(func() error {
		return h.serveMuxer(serveCtx)
	})

	errGroup.Go(func() (err error) {
		// err is the named result, so this sees the value actually returned.
		defer func() {
			if err == nil {
				connectedFuse.Connected()
			}
		}()
		// Registration RPCs use serveCtx (as ServeNamedTunnel does) so they are
		// cancelled as soon as the muxer stops. They should not keep waiting
		// on a muxer that is gone. The group still reports the muxer's error
		// first, because serveCtx is cancelled only after that error is recorded.
		if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() {
			reconnectErr := h.reconnectTunnel(serveCtx, credentialManager, classicTunnel, registrationOptions)
			if reconnectErr == nil {
				return nil
			}
			// log errors and proceed to RegisterTunnel
			h.observer.log.Err(reconnectErr).
				Uint8(LogFieldConnIndex, h.connIndex).
				Msg("Couldn't reconnect connection. Re-registering it instead.")
		}
		return h.registerTunnel(serveCtx, credentialManager, classicTunnel, registrationOptions)
	})

	errGroup.Go(func() error {
		h.controlLoop(serveCtx, connectedFuse, false)
		return nil
	})

	err := errGroup.Wait()
	if err == errMuxerStopped {
		if h.stoppedGracefully {
			return nil
		}
		h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
	}
	return err
}
// serveMuxer runs the muxer until it stops and always returns a non-nil error.
//
// Every other routine in the errgroup must stop when the muxer does. A graceful
// muxer shutdown makes Serve return nil, so that case is turned into
// errMuxerStopped. That error cancels the group's context and so notifies the
// other routines.
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
	if serveErr := h.muxer.Serve(ctx); serveErr != nil {
		return serveErr
	}
	return errMuxerStopped
}
// controlLoop handles shutdown for the connection and periodically reports
// muxer metrics until the connection stops.
//
// When gracefulShutdownC fires, it unregisters the tunnel (if connected), marks
// the connection as stopped gracefully and starts a muxer shutdown. It returns
// once that shutdown completes. When ctx is done, it unregisters (unless a
// graceful shutdown already happened), starts a muxer shutdown and returns
// immediately.
func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
	// Use a ticker that is stopped on return. time.Tick would leak one ticker
	// every time a connection's control loop exits. A non-positive frequency
	// leaves the channel nil, so it never fires. That is what time.Tick
	// returned in that case, whereas time.NewTicker would panic.
	var updateMetricsTickC <-chan time.Time
	if freq := h.muxerConfig.MetricsUpdateFreq; freq > 0 {
		ticker := time.NewTicker(freq)
		defer ticker.Stop()
		updateMetricsTickC = ticker.C
	}

	// shutdownCompleted stays nil (never selected) until a graceful shutdown
	// has been started.
	var shutdownCompleted <-chan struct{}
	for {
		select {
		case <-h.gracefulShutdownC:
			if connectedFuse.IsConnected() {
				h.unregister(isNamedTunnel)
			}
			h.stoppedGracefully = true
			// Disable this case so the signal is handled only once.
			h.gracefulShutdownC = nil
			shutdownCompleted = h.muxer.Shutdown()
		case <-shutdownCompleted:
			return
		case <-ctx.Done():
			// UnregisterTunnel blocks until the RPC call returns
			if !h.stoppedGracefully && connectedFuse.IsConnected() {
				h.unregister(isNamedTunnel)
			}
			h.muxer.Shutdown()
			// don't wait for shutdown to finish when context is closed, this is the hard termination path
			return
		case <-updateMetricsTickC:
			h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
		}
	}
}
// newRPCStream asks the muxer for a new RPC stream. It gives up when ctx is
// done or after openStreamTimeout, whichever comes first. rpcName is not used
// by this implementation.
func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
	timeoutCtx, cancelTimeout := context.WithTimeout(ctx, openStreamTimeout)
	defer cancelTimeout()

	rpcStream, openErr := h.muxer.OpenRPCStream(timeoutCtx)
	if openErr != nil {
		return nil, openErr
	}
	return rpcStream, nil
}
// ServeStream handles one stream from the edge. It turns the stream into an
// HTTP request, classifies it as a websocket upgrade or a plain HTTP request,
// and hands it to the configured origin proxy.
//
// If the request cannot be built or proxying fails, a 502 response is written
// to the stream and the error is returned.
func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
	rw := &h2muxRespWriter{stream}

	req, buildErr := h.newRequest(stream)
	if buildErr != nil {
		rw.WriteErrorResponse()
		return buildErr
	}

	connType := TypeHTTP
	if websocket.IsWebSocketUpgrade(req) {
		connType = TypeWebsocket
	}

	if proxyErr := h.config.OriginProxy.Proxy(rw, req, connType); proxyErr != nil {
		rw.WriteErrorResponse()
		return proxyErr
	}
	return nil
}
// newRequest builds an *http.Request whose body reads from stream. The request
// starts as a placeholder "GET http://localhost:8080". Its details are then
// filled in from the stream's headers by H2RequestHeadersToH1Request
// (presumably overwriting the placeholder method and URL; not visible here).
// Any error is wrapped with context.
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
	body := h2mux.MuxedStreamReader{MuxedStream: stream}
	req, newErr := http.NewRequest("GET", "http://localhost:8080", body)
	if newErr != nil {
		return nil, errors.Wrap(newErr, "Unexpected error from http.NewRequest")
	}

	if convErr := H2RequestHeadersToH1Request(stream.Headers, req); convErr != nil {
		return nil, errors.Wrap(convErr, "invalid request received")
	}
	return req, nil
}
// h2muxRespWriter writes responses back to the edge over an h2mux stream. It
// embeds the stream, so the stream's WriteHeaders and Write methods are used
// directly by WriteRespHeaders and WriteErrorResponse.
type h2muxRespWriter struct {
*h2mux.MuxedStream
}
// WriteRespHeaders converts an HTTP/1 status code and header set into h2mux
// headers. It appends ResponseMetaHeader with responseMetaHeaderOrigin and
// writes the result to the stream, returning any write error.
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
	originMeta := h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin}
	return rp.WriteHeaders(append(H1ResponseToH2ResponseHeaders(status, header), originMeta))
}
// WriteErrorResponse makes a best-effort attempt to send a 502 Bad Gateway
// response. The response carries ResponseMetaHeader set to
// responseMetaHeaderCfd, presumably marking it as generated by cloudflared
// rather than the origin. Write errors are deliberately discarded because this
// method has no error result.
func (rp *h2muxRespWriter) WriteErrorResponse() {
	badGatewayHeaders := []h2mux.Header{
		{Name: ":status", Value: "502"},
		{Name: ResponseMetaHeader, Value: responseMetaHeaderCfd},
	}
	_ = rp.WriteHeaders(badGatewayHeaders)

	body := []byte("502 Bad Gateway")
	_, _ = rp.Write(body)
}