mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:59:58 +00:00
TUN-4655: ingress.StreamBasedProxy.EstablishConnection takes dest input
This change extracts the need for EstablishConnection to know about a request's entire context. It also removes the concern of populating the http.Response from EstablishConnection's responsibilities.
This commit is contained in:
@@ -6,9 +6,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -24,7 +21,7 @@ type HTTPOriginProxy interface {
|
||||
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
||||
type StreamBasedOriginProxy interface {
|
||||
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
|
||||
EstablishConnection(dest string) (OriginConnection, error)
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@@ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
return o.resp, nil
|
||||
}
|
||||
|
||||
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
dest, err := getRequestHost(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||
conn, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
originConn := &tcpConnection{
|
||||
conn: conn,
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
return originConn, nil
|
||||
}
|
||||
|
||||
// getRequestHost returns the host of the http.Request.
|
||||
func getRequestHost(r *http.Request) (string, error) {
|
||||
if r.Host != "" {
|
||||
return r.Host, nil
|
||||
}
|
||||
if r.URL != nil {
|
||||
return r.URL.Host, nil
|
||||
}
|
||||
return "", errors.New("host not found")
|
||||
}
|
||||
|
||||
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||
var err error
|
||||
dest := o.dest
|
||||
if o.isBastion {
|
||||
dest, err = carrier.ResolveBastionDest(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if !o.isBastion {
|
||||
dest = o.dest
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
originConn := &tcpOverWSConnection{
|
||||
conn: conn,
|
||||
streamHandler: o.streamHandler,
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
Header: websocket.NewResponseHeader(r),
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
return originConn, nil
|
||||
|
||||
}
|
||||
|
||||
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
originConn := o.conn
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
Header: websocket.NewResponseHeader(r),
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||
return o.conn, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user