mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:59:58 +00:00
TUN-3868: Refactor singleTCPService and bridgeService to tcpOverWSService and rawTCPService
This commit is contained in:
@@ -7,19 +7,22 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||
)
|
||||
|
||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||
type HTTPOriginProxy interface {
|
||||
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
||||
http.RoundTripper
|
||||
}
|
||||
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
||||
type StreamBasedOriginProxy interface {
|
||||
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
|
||||
}
|
||||
@@ -28,11 +31,6 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// TODO: TUN-3636: establish connection to origins over UDS
|
||||
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
|
||||
}
|
||||
|
||||
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
req.URL.Host = o.url.Host
|
||||
@@ -51,7 +49,7 @@ func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection,
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
}
|
||||
return newWSConnection(o.transport, req)
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@@ -64,20 +62,32 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "wss"
|
||||
return newWSConnection(o.transport, req)
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
return o.resp, nil
|
||||
}
|
||||
|
||||
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
dest, err := o.destination(r)
|
||||
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
dest, err := getRequestHost(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn, err := o.client.connect(r, dest)
|
||||
return conn, nil, err
|
||||
conn, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
originConn := &tcpConnection{
|
||||
conn: conn,
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
}
|
||||
|
||||
// getRequestHost returns the host of the http.Request.
|
||||
@@ -91,10 +101,35 @@ func getRequestHost(r *http.Request) (string, error) {
|
||||
return "", errors.New("host not found")
|
||||
}
|
||||
|
||||
func (o *bridgeService) destination(r *http.Request) (string, error) {
|
||||
if connection.IsTCPStream(r) {
|
||||
return getRequestHost(r)
|
||||
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
var err error
|
||||
dest := o.dest
|
||||
if o.isBastion {
|
||||
dest, err = o.bastionDest(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
originConn := &tcpOverWSConnection{
|
||||
conn: conn,
|
||||
streamHandler: o.streamHandler,
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
Header: websocket.NewResponseHeader(r),
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
|
||||
}
|
||||
|
||||
func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) {
|
||||
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
|
||||
if jumpDestination == "" {
|
||||
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
||||
@@ -110,24 +145,3 @@ func (o *bridgeService) destination(r *http.Request) (string, error) {
|
||||
func removePath(dest string) string {
|
||||
return strings.SplitN(dest, "/", 2)[0]
|
||||
}
|
||||
|
||||
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
conn, err := o.client.connect(r, o.dest)
|
||||
return conn, nil, err
|
||||
|
||||
}
|
||||
|
||||
type tcpClient struct {
|
||||
streamHandler streamHandlerFunc
|
||||
}
|
||||
|
||||
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tcpConnection{
|
||||
conn: conn,
|
||||
streamHandler: c.streamHandler,
|
||||
}, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user