TUN-3868: Refactor singleTCPService and bridgeService to tcpOverWSService and rawTCPService

This commit is contained in:
cthuang
2021-02-05 13:01:53 +00:00
committed by Nuno Diegues
parent 5943808746
commit ab4dda5427
10 changed files with 563 additions and 212 deletions

View File

@@ -7,19 +7,22 @@ import (
"net/url"
"strings"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
var (
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
)
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
type HTTPOriginProxy interface {
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
http.RoundTripper
}
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
}
@@ -28,11 +31,6 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
// TODO: TUN-3636: establish connection to origins over UDS
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
}
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.url.Host
@@ -51,7 +49,7 @@ func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection,
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
}
return newWSConnection(o.transport, req)
return newWSConnection(o.transport.TLSClientConfig, req)
}
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -64,20 +62,32 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "wss"
return newWSConnection(o.transport, req)
return newWSConnection(o.transport.TLSClientConfig, req)
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := o.destination(r)
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := getRequestHost(r)
if err != nil {
return nil, nil, err
}
conn, err := o.client.connect(r, dest)
return conn, nil, err
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
}
originConn := &tcpConnection{
conn: conn,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
return originConn, resp, nil
}
// getRequestHost returns the host of the http.Request.
@@ -91,10 +101,35 @@ func getRequestHost(r *http.Request) (string, error) {
return "", errors.New("host not found")
}
func (o *bridgeService) destination(r *http.Request) (string, error) {
if connection.IsTCPStream(r) {
return getRequestHost(r)
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
var err error
dest := o.dest
if o.isBastion {
dest, err = o.bastionDest(r)
if err != nil {
return nil, nil, err
}
}
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
}
originConn := &tcpOverWSConnection{
conn: conn,
streamHandler: o.streamHandler,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
}
func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) {
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
if jumpDestination == "" {
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
@@ -110,24 +145,3 @@ func (o *bridgeService) destination(r *http.Request) (string, error) {
func removePath(dest string) string {
return strings.SplitN(dest, "/", 2)[0]
}
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
conn, err := o.client.connect(r, o.dest)
return conn, nil, err
}
type tcpClient struct {
streamHandler streamHandlerFunc
}
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &tcpConnection{
conn: conn,
streamHandler: c.streamHandler,
}, nil
}