mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:59:58 +00:00
TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).
This commit is contained in:
@@ -2,8 +2,10 @@ package ingress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@@ -12,7 +14,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||
errUnsupportedConnectionType = errors.New("internal error: unsupported connection type")
|
||||
)
|
||||
|
||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||
@@ -42,26 +45,64 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
||||
req.URL.Scheme = o.url.Scheme
|
||||
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
|
||||
switch req.URL.Scheme {
|
||||
case "ws":
|
||||
req.URL.Scheme = "http"
|
||||
case "wss":
|
||||
req.URL.Scheme = "https"
|
||||
}
|
||||
|
||||
if o.hostHeader != "" {
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
}
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
|
||||
return o.newWebsocketProxyConnection(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the Hello World server.
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "https"
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||
|
||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "wss"
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
req.ContentLength = 0
|
||||
req.Body = nil
|
||||
|
||||
resp, err := o.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
toClose := resp.Body
|
||||
defer func() {
|
||||
if toClose != nil {
|
||||
_ = toClose.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
|
||||
}
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
|
||||
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
|
||||
}
|
||||
|
||||
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
return nil, nil, errUnsupportedConnectionType
|
||||
}
|
||||
conn := wsProxyConnection{
|
||||
rwc: rwc,
|
||||
}
|
||||
// clear to prevent defer from closing
|
||||
toClose = nil
|
||||
|
||||
return &conn, resp, nil
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
|
Reference in New Issue
Block a user