mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 00:29:58 +00:00
TUN-3400: Use Go HTTP2 library as transport to connect with the edge
This commit is contained in:
14
origin/connection.go
Normal file
14
origin/connection.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package origin
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// persistentTCPConn is a wrapper around net.Conn that is noop when Close is called
|
||||
type persistentConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (pc *persistentConn) Close() error {
|
||||
return nil
|
||||
}
|
160
origin/server.go
Normal file
160
origin/server.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package origin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type cfdServer struct {
|
||||
httpServer *http2.Server
|
||||
originClient http.RoundTripper
|
||||
logger logger.Service
|
||||
originURL *url.URL
|
||||
connectionIndex string
|
||||
config *TunnelConfig
|
||||
}
|
||||
|
||||
func (c *cfdServer) serve(ctx context.Context, conn net.Conn) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
|
||||
Context: ctx,
|
||||
Handler: c,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *cfdServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c.config.Metrics.incrementRequests(c.connectionIndex)
|
||||
defer c.config.Metrics.decrementConcurrentRequests(c.connectionIndex)
|
||||
|
||||
cfRay := findCfRayHeader(r)
|
||||
lbProbe := isLBProbeRequest(r)
|
||||
c.logRequest(r, cfRay, lbProbe)
|
||||
|
||||
r.URL = c.originURL
|
||||
// TODO: TUN-3406 support websocket, event stream and WSGI servers.
|
||||
var resp *http.Response
|
||||
var err error
|
||||
if strings.ToLower(r.Header.Get("Cf-Int-Argo-Tunnel-Upgrade")) == "websocket" {
|
||||
resp, err = serveWebsocket(newWebsocketBody(w, r, c.logger), r, c.config.HTTPHostHeader, c.config.ClientTlsConfig)
|
||||
} else {
|
||||
resp, err = c.originClient.RoundTrip(r)
|
||||
}
|
||||
if err != nil {
|
||||
c.writeErrorResponse(w, err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(w, resp.Body)
|
||||
if err != nil {
|
||||
c.logger.Errorf("Copy response error, err: %v", err)
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cfdServer) writeErrorResponse(w http.ResponseWriter, err error) {
|
||||
c.logger.Errorf("HTTP request error: %s", err)
|
||||
c.config.Metrics.incrementResponses(c.connectionIndex, "502")
|
||||
jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader))
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
}
|
||||
|
||||
func (c *cfdServer) logRequest(r *http.Request, cfRay string, lbProbe bool) {
|
||||
logger := c.logger
|
||||
if cfRay != "" {
|
||||
logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
|
||||
} else if lbProbe {
|
||||
logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
|
||||
} else {
|
||||
logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto)
|
||||
}
|
||||
logger.Infof("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
|
||||
|
||||
if contentLen := r.ContentLength; contentLen == -1 {
|
||||
logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
|
||||
} else {
|
||||
logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cfdServer) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
|
||||
c.config.Metrics.incrementResponses(c.connectionIndex, "200")
|
||||
logger := c.logger
|
||||
if cfRay != "" {
|
||||
logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
|
||||
} else if lbProbe {
|
||||
logger.Debugf("Response to Load Balancer health check %s", r.Status)
|
||||
} else {
|
||||
logger.Infof("%s", r.Status)
|
||||
}
|
||||
logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
|
||||
|
||||
if contentLen := r.ContentLength; contentLen == -1 {
|
||||
logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
|
||||
} else {
|
||||
logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
|
||||
}
|
||||
}
|
||||
|
||||
type WebsocketResp interface {
|
||||
WriteRespHeaders(*http.Response) error
|
||||
io.ReadWriter
|
||||
}
|
||||
|
||||
type http2WebsocketResp struct {
|
||||
pr *io.PipeReader
|
||||
w http.ResponseWriter
|
||||
}
|
||||
|
||||
func newWebsocketBody(w http.ResponseWriter, r *http.Request, logger logger.Service) *http2WebsocketResp {
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
n, err := io.Copy(pw, r.Body)
|
||||
logger.Errorf("websocket body copy ended, err: %v, bytes: %d", err, n)
|
||||
}()
|
||||
return &http2WebsocketResp{pr: pr, w: w}
|
||||
}
|
||||
|
||||
func (wr *http2WebsocketResp) WriteRespHeaders(resp *http.Response) error {
|
||||
dest := wr.w.Header()
|
||||
for name, values := range resp.Header {
|
||||
for _, v := range values {
|
||||
dest.Add(name, v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wr *http2WebsocketResp) Read(p []byte) (n int, err error) {
|
||||
return wr.pr.Read(p)
|
||||
}
|
||||
|
||||
func (wr *http2WebsocketResp) Write(p []byte) (n int, err error) {
|
||||
return wr.w.Write(p)
|
||||
}
|
||||
|
||||
type h2muxWebsocketResp struct {
|
||||
*h2mux.MuxedStream
|
||||
}
|
||||
|
||||
func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error {
|
||||
return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp))
|
||||
}
|
120
origin/tunnel.go
120
origin/tunnel.go
@@ -17,7 +17,9 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/cloudflare/cloudflared/buffer"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
@@ -30,6 +32,7 @@ import (
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
@@ -304,7 +307,14 @@ func ServeTunnel(
|
||||
connectionTag := uint8ToString(connectionIndex)
|
||||
|
||||
if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol {
|
||||
return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh)
|
||||
tlsConn, err := RegisterConnection(ctx, config, connectionIndex, uint8(backoff.retries), addr)
|
||||
if err != nil {
|
||||
logger.Errorf("Register connectio error: %+v", err)
|
||||
return err, true
|
||||
}
|
||||
connectedFuse.Fuse(true)
|
||||
backoff.SetGracePeriod()
|
||||
return serveNamedTunnel(ctx, config, tlsConn, connectionIndex, reconnectCh)
|
||||
}
|
||||
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
@@ -332,10 +342,6 @@ func ServeTunnel(
|
||||
}
|
||||
}()
|
||||
|
||||
if config.NamedTunnel != nil {
|
||||
return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr, uint8(backoff.retries))
|
||||
}
|
||||
|
||||
if config.UseReconnectToken && connectedFuse.Value() {
|
||||
err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
|
||||
if err == nil {
|
||||
@@ -426,7 +432,55 @@ func ServeTunnel(
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func RegisterConnection(
|
||||
func serveNamedTunnel(
|
||||
ctx context.Context,
|
||||
config *TunnelConfig,
|
||||
tlsConn net.Conn,
|
||||
connectionIndex uint8,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
) (err error, recoverable bool) {
|
||||
originURLStr, err := validation.ValidateUrl(config.OriginUrl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse origin URL %#v", config.OriginUrl), false
|
||||
}
|
||||
originURL, err := url.Parse(originURLStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse origin URL %#v", originURLStr), false
|
||||
}
|
||||
|
||||
originClient := config.HTTPTransport
|
||||
if originClient == nil {
|
||||
originClient = http.DefaultTransport
|
||||
}
|
||||
|
||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
cfdServer := &cfdServer{
|
||||
httpServer: &http2.Server{},
|
||||
originClient: originClient,
|
||||
logger: config.Logger,
|
||||
originURL: originURL,
|
||||
connectionIndex: uint8ToString(connectionIndex),
|
||||
config: config,
|
||||
}
|
||||
cfdServer.serve(serveCtx, tlsConn)
|
||||
return fmt.Errorf("Connection with edge closed")
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
select {
|
||||
case reconnect := <-reconnectCh:
|
||||
return &reconnect
|
||||
case <-serveCtx.Done():
|
||||
return nil
|
||||
}
|
||||
})
|
||||
|
||||
err = errGroup.Wait()
|
||||
return err, true
|
||||
}
|
||||
|
||||
func RegisterConnectionWithH2Mux(
|
||||
ctx context.Context,
|
||||
muxer *h2mux.Muxer,
|
||||
config *TunnelConfig,
|
||||
@@ -470,6 +524,52 @@ func RegisterConnection(
|
||||
return nil
|
||||
}
|
||||
|
||||
func RegisterConnection(
|
||||
ctx context.Context,
|
||||
config *TunnelConfig,
|
||||
connectionID uint8,
|
||||
numPreviousAttempts uint8,
|
||||
addr *net.TCPAddr,
|
||||
) (net.Conn, error) {
|
||||
originCert, err := tls.X509KeyPair(config.OriginCert, config.OriginCert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig := config.TlsConfig
|
||||
tlsConfig.Certificates = []tls.Certificate{originCert}
|
||||
tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rpcTransport := tunnelrpc.NewTransportLogger(config.Logger, rpc.StreamTransport(&persistentConn{tlsServerConn}))
|
||||
rpcConn := rpc.NewConn(
|
||||
rpcTransport,
|
||||
tunnelrpc.ConnLog(config.Logger),
|
||||
)
|
||||
rpcClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx), Conn: rpcConn}
|
||||
connDetail, err := rpcClient.RegisterConnection(
|
||||
ctx,
|
||||
config.NamedTunnel.Auth,
|
||||
config.NamedTunnel.ID,
|
||||
connectionID,
|
||||
config.ConnectionOptions(tlsServerConn.LocalAddr().String(), numPreviousAttempts),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Logger.Infof("Connection %d registered with %s using ID %s", connectionID, connDetail.Location, connDetail.UUID)
|
||||
rpcTransport.Close()
|
||||
// Closing the client will also close the connection
|
||||
rpcClient.Close()
|
||||
|
||||
flushMessage := make([]byte, 8)
|
||||
buf := make([]byte, len(flushMessage))
|
||||
tlsServerConn.Write(buf)
|
||||
|
||||
return tlsServerConn, nil
|
||||
}
|
||||
|
||||
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
|
||||
if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
|
||||
return &serverRegisterTunnelError{
|
||||
@@ -698,7 +798,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||
var resp *http.Response
|
||||
var respErr error
|
||||
if websocket.IsWebSocketUpgrade(req) {
|
||||
resp, respErr = h.serveWebsocket(stream, req, rule)
|
||||
resp, respErr = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule)
|
||||
} else {
|
||||
resp, respErr = h.serveHTTP(stream, req, rule)
|
||||
}
|
||||
@@ -725,7 +825,7 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request,
|
||||
return req, rule, nil
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
||||
func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
||||
req.Header.Set("Host", hostHeader)
|
||||
req.Host = hostHeader
|
||||
@@ -740,13 +840,13 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
err = stream.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(response))
|
||||
err = wsResp.WriteRespHeaders(response)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error writing response header")
|
||||
}
|
||||
// Copy to/from stream to the undelying connection. Use the underlying
|
||||
// connection because cloudflared doesn't operate on the message themselves
|
||||
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||
websocket.Stream(conn.UnderlyingConn(), wsResp)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user