TUN-3817: Adds tests for websocket based streaming regression

This commit is contained in:
Sudarsan Reddy
2021-02-02 18:27:50 +00:00
committed by Nuno Diegues
parent 6681d179dc
commit a6c2348127
5 changed files with 374 additions and 167 deletions

View File

@@ -5,6 +5,7 @@ import (
"net"
"net/http"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/websocket"
gws "github.com/gorilla/websocket"
)
@@ -15,6 +16,7 @@ type OriginConnection interface {
// Stream should generally be implemented as a bidirectional io.Copy.
Stream(tunnelConn io.ReadWriter)
Close()
Type() connection.Type
}
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn)
@@ -57,6 +59,10 @@ func (tc *tcpConnection) Close() {
tc.conn.Close()
}
func (*tcpConnection) Type() connection.Type {
return connection.TypeTCP
}
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
type wsConnection struct {
@@ -73,6 +79,10 @@ func (wsc *wsConnection) Close() {
wsc.wsConn.Close()
}
func (wsc *wsConnection) Type() connection.Type {
return connection.TypeWebsocket
}
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
d := &gws.Dialer{
TLSClientConfig: transport.TLSClientConfig,

View File

@@ -9,6 +9,7 @@ import (
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
@@ -39,6 +40,12 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) {
req.URL.Host = o.url.Host
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
return newWSConnection(o.transport, req)
}
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the Hello World server.
req.URL.Host = o.server.Addr().String()