cloudflared/connection/quic.go
Nuno Diegues 1086d5ede5 TUN-5204: Unregister QUIC transports on disconnect
This adds various bug fixes found while investigating why QUIC transports were
not being unregistered when they should have been (instead, they were only
unregistered once the graceful shutdown started).

Most of these bug fixes are making the QUIC transport implementation closer
to its HTTP2 counterpart:
 - ServeControlStream is now a blocking function (it's up to the transport to handle that)
 - QUIC transport then handles the control plane as part of its Serve, thus waiting for it on shutdown
 - QUIC transport now returns "non recoverable" for connections with similar semantics to HTTP2 and H2mux
 - QUIC transport no longer has a loop around its Serve logic that retries connections on its own (that logic is upstream)
2022-01-06 10:08:38 +00:00

354 lines
12 KiB
Go

package connection
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress"
quicpogs "github.com/cloudflare/cloudflared/quic"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
	// HTTPHeaderKey is the metadata key prefix used to get or set HTTP headers on a QUIC connect
	// request/response when the underlying proxy connection type is HTTP. Each header value is
	// carried under the key "HttpHeader:<header name>".
	HTTPHeaderKey = "HttpHeader"
	// HTTPMethodKey is the metadata key used to get or set the HTTP method on a QUIC connect
	// request when the underlying proxy connection type is HTTP.
	HTTPMethodKey = "HttpMethod"
	// HTTPHostKey is the metadata key used to get or set the HTTP host on a QUIC connect
	// request when the underlying proxy connection type is HTTP.
	HTTPHostKey = "HttpHost"
)
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
// It wraps a single QUIC session to the edge and serves, over that session, the registration
// control stream, edge-initiated data/RPC streams, and UDP sessions carried as QUIC datagrams.
type QUICConnection struct {
// session is the QUIC session dialed to the edge by NewQUICConnection.
session quic.Session
logger *zerolog.Logger
// httpProxy proxies HTTP/WebSocket requests (ProxyHTTP) and TCP streams (ProxyTCP) to the origin.
httpProxy OriginProxy
// sessionManager tracks UDP sessions whose datagrams are multiplexed over session.
sessionManager datagramsession.Manager
// controlStreamHandler serves the registration control stream and reports whether it has stopped.
controlStreamHandler ControlStreamHandler
// connOptions are handed to controlStreamHandler when the control stream is served.
connOptions *tunnelpogs.ConnectionOptions
}
// NewQUICConnection dials edgeAddr over QUIC and returns a QUICConnection that is ready to Serve.
//
// The returned connection owns the dialed QUIC session. If a setup step after the dial fails,
// the session is closed before returning so it is not leaked.
func NewQUICConnection(
	quicConfig *quic.Config,
	edgeAddr net.Addr,
	tlsConfig *tls.Config,
	httpProxy OriginProxy,
	connOptions *tunnelpogs.ConnectionOptions,
	controlStreamHandler ControlStreamHandler,
	logger *zerolog.Logger,
) (*QUICConnection, error) {
	session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to dial to edge: %w", err)
	}

	datagramMuxer, err := quicpogs.NewDatagramMuxer(session)
	if err != nil {
		// Nothing else holds a reference to the freshly dialed session yet, so tear it down here.
		// The close error is discarded because the muxer error is the one worth reporting.
		_ = session.CloseWithError(0, "")
		return nil, err
	}

	sessionManager := datagramsession.NewManager(datagramMuxer, logger)
	return &QUICConnection{
		session:              session,
		httpProxy:            httpProxy,
		logger:               logger,
		sessionManager:       sessionManager,
		controlStreamHandler: controlStreamHandler,
		connOptions:          connOptions,
	}, nil
}
// Serve runs the connection until one of its three workers stops. The workers are the control
// stream server, the stream acceptor, and the datagram session manager. Serve returns the first
// non-nil error reported by any of them, or nil.
//
// The control stream is opened before anything else because origintunneld assumes the first
// stream on the session is the control plane.
func (q *QUICConnection) Serve(ctx context.Context) error {
	controlStream, err := q.session.OpenStream()
	if err != nil {
		return fmt.Errorf("failed to open a registration control stream: %w", err)
	}

	// Whichever worker finishes first calls stopAll so the others exit as fast as possible.
	// A nil return means "exit for good" (the caller won't retry serving this connection), so
	// this explicit cancellation is what tears the remaining workers down. A non-nil return
	// additionally cancels groupCtx through the errgroup.
	serveCtx, stopAll := context.WithCancel(ctx)
	group, groupCtx := errgroup.WithContext(serveCtx)

	launch := func(worker func(context.Context) error) {
		group.Go(func() error {
			defer stopAll()
			return worker(groupCtx)
		})
	}

	// In the future, if cloudflared can autonomously push traffic to the edge, the control stream
	// must be fully registered before the other workers are allowed to proceed.
	launch(func(c context.Context) error { return q.serveControlStream(c, controlStream) })
	launch(func(c context.Context) error { return q.acceptStream(c) })
	launch(func(c context.Context) error { return q.sessionManager.Serve(c) })

	return group.Wait()
}
// serveControlStream blocks while the control plane is served on controlStream and then returns
// the handler's result. The error is deliberately returned unwrapped so the message stays
// consistent with the http2 transport.
func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error {
	return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions)
}
// acceptStream accepts streams opened on the QUIC session and handles each one in its own
// goroutine. It closes the QUIC session when it returns.
//
// It returns nil when accepting stopped because ctx was canceled or the control stream handler
// reports it has stopped; both are intentional shutdowns. Any other accept failure is returned
// wrapped.
func (q *QUICConnection) acceptStream(ctx context.Context) error {
	defer q.Close()
	for {
		stream, err := q.session.AcceptStream(ctx)
		if err != nil {
			// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
			if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
				return nil
			}
			return fmt.Errorf("failed to accept QUIC stream: %w", err)
		}
		// Hand the stream to the goroutine explicitly and keep the handler's error local to it,
		// so the goroutine never writes to a variable that belongs to the accept loop.
		go func(stream quic.Stream) {
			defer stream.Close()
			if err := q.handleStream(stream); err != nil {
				q.logger.Err(err).Msg("Failed to handle QUIC stream")
			}
		}(stream)
	}
}
// Close closes the QUIC session with application error code 0 and an empty reason.
// Close has no return value, so any error from CloseWithError is deliberately discarded.
func (q *QUICConnection) Close() {
	const noErrorCode = 0
	_ = q.session.CloseWithError(noErrorCode, "")
}
// handleStream reads the protocol signature at the start of stream and dispatches the stream to
// the data-stream handler or the RPC handler. Every failure, including an unknown signature, is
// returned so the caller can log it.
func (q *QUICConnection) handleStream(stream quic.Stream) error {
	signature, err := quicpogs.DetermineProtocol(stream)
	if err != nil {
		return err
	}
	switch signature {
	case quicpogs.DataStreamProtocolSignature:
		reqServerStream, err := quicpogs.NewRequestServerStream(stream, signature)
		if err != nil {
			// Surface the failure, as the RPC branch does, instead of silently dropping the stream.
			return err
		}
		return q.handleDataStream(reqServerStream)
	case quicpogs.RPCStreamProtocolSignature:
		rpcStream, err := quicpogs.NewRPCServerStream(stream, signature)
		if err != nil {
			return err
		}
		return q.handleRPCStream(rpcStream)
	default:
		return fmt.Errorf("unknown protocol %v", signature)
	}
}
// handleDataStream reads the connect request at the start of stream and proxies it to the origin.
// HTTP and WebSocket requests go through ProxyHTTP, and TCP goes through ProxyTCP. An unrecognized
// connection type is reported as an error rather than silently ignored, consistent with
// handleStream's handling of unknown protocols.
func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) error {
	connectRequest, err := stream.ReadConnectRequestData()
	if err != nil {
		return err
	}
	switch connectRequest.Type {
	case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
		req, err := buildHTTPRequest(connectRequest, stream)
		if err != nil {
			return err
		}
		w := newHTTPResponseAdapter(stream)
		return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
	case quicpogs.ConnectionTypeTCP:
		rwa := &streamReadWriteAcker{stream}
		// NOTE(review): context.Background() is never canceled, so this proxy is not tied to the
		// lifetime of the QUIC session; consider threading a connection-scoped context through.
		return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
	default:
		return fmt.Errorf("unsupported connection type %v", connectRequest.Type)
	}
}
// handleRPCStream serves RPCs arriving on rpcStream, using the connection itself as the handler.
// QUICConnection provides RegisterUdpSession and UnregisterUdpSession for that purpose.
func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error {
	rpcHandler := q
	return rpcStream.Serve(rpcHandler, q.logger)
}
// RegisterUdpSession is the RPC method invoked by edge to register and run a session.
//
// It dials the UDP origin at dstIP:dstPort and registers the resulting socket with the session
// manager under sessionID. It then serves the session in a background goroutine. If registration
// fails, the origin socket is closed before returning because nothing else owns it yet.
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration) error {
	// Each session is a series of datagrams from an eyeball to a dstIP:dstPort.
	// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
	originProxy, err := ingress.DialUDP(dstIP, dstPort)
	if err != nil {
		q.logger.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
		return err
	}
	session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
	if err != nil {
		// The session manager did not take ownership of the socket, so release it here.
		_ = originProxy.Close()
		q.logger.Err(err).Str("sessionID", sessionID.String()).Msg("Failed to register udp session")
		return err
	}
	go q.serveUDPSession(session, closeAfterIdleHint)
	q.logger.Debug().Msgf("Registered session %v, %v, %v", sessionID, dstIP, dstPort)
	return nil
}
// serveUDPSession serves session until it ends, bound to the QUIC session's context.
// When the session did not end because the remote closed it, it is unregistered locally and
// with the edge. The serve error message, or a fixed reason, is used as the unregister message.
func (q *QUICConnection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) {
	sessionCtx := q.session.Context()
	closedByRemote, serveErr := session.Serve(sessionCtx, closeAfterIdleHint)

	// A remote close means the session is already unregistered from both the manager and the edge.
	if !closedByRemote {
		reason := "terminated without error"
		if serveErr != nil {
			reason = serveErr.Error()
		}
		q.closeUDPSession(sessionCtx, session.ID, reason)
	}

	q.logger.Debug().Err(serveErr).Str("sessionID", session.ID.String()).Msg("Session terminated")
}
// closeUDPSession first unregisters the session from the session manager, then tries to
// unregister it from the edge over a newly opened RPC stream. Failures are logged, never
// returned.
func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
	if err := q.sessionManager.UnregisterSession(ctx, sessionID, message, false); err != nil {
		// NOTE(review): logged at debug on the assumption that the session may already have been
		// removed (e.g. during shutdown); confirm whether this deserves a higher level.
		q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).
			Msg("Failed to unregister udp session from session manager")
	}

	stream, err := q.session.OpenStream()
	if err != nil {
		// Log this at debug because this is not an error if session was closed due to lost connection
		// with edge
		q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).
			Msg("Failed to open quic stream to unregister udp session with edge")
		return
	}
	// This stream exists only for the single unregister RPC below; close it when done.
	defer stream.Close()

	rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.logger)
	if err != nil {
		// The stream was opened, so failing to set up RPC on it is logged as an error.
		q.logger.Err(err).Str("sessionID", sessionID.String()).
			Msg("Failed to open rpc stream to unregister udp session with edge")
		return
	}

	if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil {
		q.logger.Err(err).Str("sessionID", sessionID.String()).
			Msg("Failed to unregister udp session with edge")
	}
}
// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a session.
// It only removes the session from the local session manager and does not call back to the edge,
// because the edge initiated this request.
func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
	// NOTE(review): the final argument appears to mark the unregistration as remote-initiated
	// (closeUDPSession passes false on the local path); confirm against datagramsession.Manager.
	const byRemote = true
	return q.sessionManager.UnregisterSession(ctx, sessionID, message, byRemote)
}
// streamReadWriteAcker wraps a RequestServerStream for TCP proxying. The stream's methods are
// promoted through embedding, and AckConnection adds a way to send an acknowledgement back to
// the client.
type streamReadWriteAcker struct {
	*quicpogs.RequestServerStream
}

// AckConnection acks response back to the proxy by writing a connect response that carries no error.
func (a *streamReadWriteAcker) AckConnection() error {
	var noErr error
	return a.WriteConnectResponseData(noErr)
}
// httpResponseAdapter translates responses written by the HTTP proxy into QUIC connect responses.
// It is a thin value wrapper; all stream I/O goes through the embedded RequestServerStream.
type httpResponseAdapter struct {
	*quicpogs.RequestServerStream
}

// newHTTPResponseAdapter wraps s so the HTTP proxy can write its response to it.
func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter {
	return httpResponseAdapter{RequestServerStream: s}
}
// WriteRespHeaders sends the origin's status code and headers back as connect-response metadata.
// The status is carried under "HttpStatus". Each header value is carried under
// "HttpHeader:<name>", one metadata entry per value, so multi-valued headers are preserved.
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
	// One slot for the status plus roughly one per header name; multi-valued headers may grow it.
	metadata := make([]quicpogs.Metadata, 0, len(header)+1)
	metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
	for name, values := range header {
		// Build the key once per header name instead of formatting it for every value.
		key := HTTPHeaderKey + ":" + name
		for _, v := range values {
			metadata = append(metadata, quicpogs.Metadata{Key: key, Val: v})
		}
	}
	return hrw.WriteConnectResponseData(nil, metadata...)
}
// WriteErrorResponse reports err back as a connect response carrying a 502 Bad Gateway status.
// The method has no error return, so a failure to write that response is intentionally dropped.
func (hrw httpResponseAdapter) WriteErrorResponse(err error) {
	badGateway := quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}
	_ = hrw.WriteConnectResponseData(err, badGateway)
}
// buildHTTPRequest converts a QUIC connect request into an *http.Request whose body is read from body.
//
// The method and Host come from the HttpMethod and HttpHost metadata entries. Every metadata entry
// whose key starts with "HttpHeader" must have the form "HttpHeader:<name>" and becomes a request
// header; any other shape is rejected as malformed. Content-Length is parsed into
// req.ContentLength. The body is dropped for non-WebSocket requests that are neither chunked nor
// have a non-zero content length.
func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*http.Request, error) {
	metadata := connectRequest.MetadataMap()
	dest := connectRequest.Dest
	method := metadata[HTTPMethodKey]
	host := metadata[HTTPHostKey]
	isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket

	req, err := http.NewRequest(method, dest, body)
	if err != nil {
		return nil, err
	}
	req.Host = host

	for _, entry := range connectRequest.Metadata {
		// entry.Key is of the format HttpHeader:<HTTPHeader>. Match on the prefix rather than
		// anywhere in the key, so unrelated keys that merely contain "HttpHeader" are not parsed
		// as headers.
		if !strings.HasPrefix(entry.Key, HTTPHeaderKey) {
			continue
		}
		parts := strings.Split(entry.Key, ":")
		if len(parts) != 2 {
			return nil, fmt.Errorf("header Key: %s malformed", entry.Key)
		}
		req.Header.Add(parts[1], entry.Val)
	}

	// Go's http.Client automatically sends chunked request body if this value is not set on the
	// *http.Request struct regardless of header:
	// https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154.
	if err := setContentLength(req); err != nil {
		return nil, fmt.Errorf("Error setting content-length: %w", err)
	}

	// Go's client defaults to chunked encoding after a 200ms delay if the following cases are true:
	// * the request body blocks
	// * the content length is not set (or set to -1)
	// * the method doesn't usually have a body (GET, HEAD, DELETE, ...)
	// * there is no transfer-encoding=chunked already set.
	// So, if transfer cannot be chunked and content length is 0, we dont set a request body.
	if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 {
		req.Body = nil
	}
	stripWebsocketUpgradeHeader(req)
	return req, nil
}
// setContentLength copies a Content-Length header, if present, into req.ContentLength.
// A missing header leaves req untouched. A header that does not parse as a base-10 int64 yields
// the parse error, and req.ContentLength is still overwritten with whatever ParseInt returned.
func setContentLength(req *http.Request) error {
	raw := req.Header.Get("Content-Length")
	if raw == "" {
		return nil
	}
	length, parseErr := strconv.ParseInt(raw, 10, 64)
	req.ContentLength = length
	return parseErr
}
// isTransferEncodingChunked reports whether req's Transfer-Encoding header mentions "chunked",
// case-insensitively. The header may be a comma-separated list
// (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding), so the check is
// a substring match rather than an exact comparison.
func isTransferEncodingChunked(req *http.Request) bool {
	lowered := strings.ToLower(req.Header.Get("Transfer-Encoding"))
	return strings.Index(lowered, "chunked") != -1
}