TUN-4655: ingress.StreamBasedProxy.EstablishConnection takes dest input

This change extracts the need for EstablishConnection to know about a
request's entire context. It also removes the concern of populating the
http.Response from EstablishConnection's responsibilities.
This commit is contained in:
Sudarsan Reddy
2021-07-01 19:30:26 +01:00
parent f1b57526b3
commit d678584d89
7 changed files with 69 additions and 94 deletions

View File

@@ -6,9 +6,6 @@ import (
"net/http"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/websocket"
)
var (
@@ -24,7 +21,7 @@ type HTTPOriginProxy interface {
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
EstablishConnection(dest string) (OriginConnection, error)
}
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := getRequestHost(r)
if err != nil {
return nil, nil, err
}
func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) {
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
return nil, err
}
originConn := &tcpConnection{
conn: conn,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
return originConn, resp, nil
return originConn, nil
}
// getRequestHost returns the host of the http.Request.
func getRequestHost(r *http.Request) (string, error) {
if r.Host != "" {
return r.Host, nil
}
if r.URL != nil {
return r.URL.Host, nil
}
return "", errors.New("host not found")
}
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
var err error
dest := o.dest
if o.isBastion {
dest, err = carrier.ResolveBastionDest(r)
if err != nil {
return nil, nil, err
}
if !o.isBastion {
dest = o.dest
}
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
return nil, err
}
originConn := &tcpOverWSConnection{
conn: conn,
streamHandler: o.streamHandler,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
return originConn, nil
}
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
originConn := o.conn
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
return o.conn, nil
}