mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:29:57 +00:00
TUN-1977: Validate OriginConfig has valid URL, and use scheme to determine if a HTTPOriginService is expecting HTTP or Unix
This commit is contained in:
@@ -134,49 +134,18 @@ type OriginConfig interface {
|
||||
}
|
||||
|
||||
type HTTPOriginConfig struct {
|
||||
URL OriginAddr `capnp:"url"`
|
||||
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive"`
|
||||
DialDualStack bool
|
||||
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout"`
|
||||
TLSVerify bool `capnp:"tlsVerify"`
|
||||
OriginCAPool string
|
||||
OriginServerName string
|
||||
MaxIdleConnections uint64
|
||||
IdleConnectionTimeout time.Duration
|
||||
ProxyConnectTimeout time.Duration
|
||||
ExpectContinueTimeout time.Duration
|
||||
ChunkedEncoding bool
|
||||
}
|
||||
|
||||
type OriginAddr interface {
|
||||
Addr() string
|
||||
}
|
||||
|
||||
type HTTPURL struct {
|
||||
URL *url.URL
|
||||
}
|
||||
|
||||
func (ha *HTTPURL) Addr() string {
|
||||
return ha.URL.String()
|
||||
}
|
||||
|
||||
func (ha *HTTPURL) capnpHTTPURL() *CapnpHTTPURL {
|
||||
return &CapnpHTTPURL{
|
||||
URL: ha.URL.String(),
|
||||
}
|
||||
}
|
||||
|
||||
// URL for a HTTP origin, capnp doesn't have native support for URL, so represent it as string
|
||||
type CapnpHTTPURL struct {
|
||||
URL string `capnp:"url"`
|
||||
}
|
||||
|
||||
type UnixPath struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
func (up *UnixPath) Addr() string {
|
||||
return up.Path
|
||||
URLString string `capnp:"urlString"`
|
||||
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive"`
|
||||
DialDualStack bool
|
||||
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout"`
|
||||
TLSVerify bool `capnp:"tlsVerify"`
|
||||
OriginCAPool string
|
||||
OriginServerName string
|
||||
MaxIdleConnections uint64
|
||||
IdleConnectionTimeout time.Duration
|
||||
ProxyConnectionTimeout time.Duration
|
||||
ExpectContinueTimeout time.Duration
|
||||
ChunkedEncoding bool
|
||||
}
|
||||
|
||||
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
||||
@@ -184,8 +153,9 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialContext := (&net.Dialer{
|
||||
Timeout: hc.ProxyConnectTimeout,
|
||||
Timeout: hc.ProxyConnectionTimeout,
|
||||
KeepAlive: hc.TCPKeepAlive,
|
||||
DualStack: hc.DialDualStack,
|
||||
}).DialContext
|
||||
@@ -202,18 +172,22 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
||||
IdleConnTimeout: hc.IdleConnectionTimeout,
|
||||
ExpectContinueTimeout: hc.ExpectContinueTimeout,
|
||||
}
|
||||
if unixPath, ok := hc.URL.(*UnixPath); ok {
|
||||
url, err := url.Parse(hc.URLString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "%s is not a valid URL", hc.URLString)
|
||||
}
|
||||
if url.Scheme == "unix" {
|
||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialContext(ctx, "unix", unixPath.Addr())
|
||||
return dialContext(ctx, "unix", url.Host)
|
||||
}
|
||||
}
|
||||
return originservice.NewHTTPService(transport, hc.URL.Addr(), hc.ChunkedEncoding), nil
|
||||
return originservice.NewHTTPService(transport, url, hc.ChunkedEncoding), nil
|
||||
}
|
||||
|
||||
func (_ *HTTPOriginConfig) isOriginConfig() {}
|
||||
|
||||
type WebSocketOriginConfig struct {
|
||||
URL string `capnp:"url"`
|
||||
URLString string `capnp:"urlString"`
|
||||
TLSVerify bool `capnp:"tlsVerify"`
|
||||
OriginCAPool string
|
||||
OriginServerName string
|
||||
@@ -229,7 +203,12 @@ func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error)
|
||||
ServerName: wsc.OriginServerName,
|
||||
InsecureSkipVerify: wsc.TLSVerify,
|
||||
}
|
||||
return originservice.NewWebSocketService(tlsConfig, wsc.URL)
|
||||
|
||||
url, err := url.Parse(wsc.URLString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "%s is not a valid URL", wsc.URLString)
|
||||
}
|
||||
return originservice.NewWebSocketService(tlsConfig, url)
|
||||
}
|
||||
|
||||
func (_ *WebSocketOriginConfig) isOriginConfig() {}
|
||||
@@ -550,115 +529,12 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
|
||||
}
|
||||
|
||||
func MarshalHTTPOriginConfig(s tunnelrpc.HTTPOriginConfig, p *HTTPOriginConfig) error {
|
||||
switch originAddr := p.URL.(type) {
|
||||
case *HTTPURL:
|
||||
ss, err := s.OriginAddr().NewHttp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := MarshalHTTPURL(ss, originAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
case *UnixPath:
|
||||
ss, err := s.OriginAddr().NewUnix()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := MarshalUnixPath(ss, originAddr); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("Unknown type for OriginAddr: %T", originAddr)
|
||||
}
|
||||
s.SetTcpKeepAlive(p.TCPKeepAlive.Nanoseconds())
|
||||
s.SetDialDualStack(p.DialDualStack)
|
||||
s.SetTlsHandshakeTimeout(p.TLSHandshakeTimeout.Nanoseconds())
|
||||
s.SetTlsVerify(p.TLSVerify)
|
||||
s.SetOriginCAPool(p.OriginCAPool)
|
||||
s.SetOriginServerName(p.OriginServerName)
|
||||
s.SetMaxIdleConnections(p.MaxIdleConnections)
|
||||
s.SetIdleConnectionTimeout(p.IdleConnectionTimeout.Nanoseconds())
|
||||
s.SetProxyConnectionTimeout(p.ProxyConnectTimeout.Nanoseconds())
|
||||
s.SetExpectContinueTimeout(p.ExpectContinueTimeout.Nanoseconds())
|
||||
s.SetChunkedEncoding(p.ChunkedEncoding)
|
||||
return nil
|
||||
return pogs.Insert(tunnelrpc.HTTPOriginConfig_TypeID, s.Struct, p)
|
||||
}
|
||||
|
||||
func UnmarshalHTTPOriginConfig(s tunnelrpc.HTTPOriginConfig) (*HTTPOriginConfig, error) {
|
||||
p := new(HTTPOriginConfig)
|
||||
switch s.OriginAddr().Which() {
|
||||
case tunnelrpc.HTTPOriginConfig_originAddr_Which_http:
|
||||
ss, err := s.OriginAddr().Http()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
originAddr, err := UnmarshalCapnpHTTPURL(ss)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.URL = originAddr
|
||||
case tunnelrpc.HTTPOriginConfig_originAddr_Which_unix:
|
||||
ss, err := s.OriginAddr().Unix()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
originAddr, err := UnmarshalUnixPath(ss)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.URL = originAddr
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown type for OriginAddr: %T", s.OriginAddr().Which())
|
||||
}
|
||||
p.TCPKeepAlive = time.Duration(s.TcpKeepAlive())
|
||||
p.DialDualStack = s.DialDualStack()
|
||||
p.TLSHandshakeTimeout = time.Duration(s.TlsHandshakeTimeout())
|
||||
p.TLSVerify = s.TlsVerify()
|
||||
originCAPool, err := s.OriginCAPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.OriginCAPool = originCAPool
|
||||
originServerName, err := s.OriginServerName()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.OriginServerName = originServerName
|
||||
p.MaxIdleConnections = s.MaxIdleConnections()
|
||||
p.IdleConnectionTimeout = time.Duration(s.IdleConnectionTimeout())
|
||||
p.ProxyConnectTimeout = time.Duration(s.ProxyConnectionTimeout())
|
||||
p.ExpectContinueTimeout = time.Duration(s.ExpectContinueTimeout())
|
||||
p.ChunkedEncoding = s.ChunkedEncoding()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func MarshalHTTPURL(s tunnelrpc.CapnpHTTPURL, p *HTTPURL) error {
|
||||
return pogs.Insert(tunnelrpc.CapnpHTTPURL_TypeID, s.Struct, p.capnpHTTPURL())
|
||||
}
|
||||
|
||||
func UnmarshalCapnpHTTPURL(s tunnelrpc.CapnpHTTPURL) (*HTTPURL, error) {
|
||||
p := new(CapnpHTTPURL)
|
||||
err := pogs.Extract(p, tunnelrpc.CapnpHTTPURL_TypeID, s.Struct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
url, err := url.Parse(p.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &HTTPURL{
|
||||
URL: url,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func MarshalUnixPath(s tunnelrpc.UnixPath, p *UnixPath) error {
|
||||
err := pogs.Insert(tunnelrpc.UnixPath_TypeID, s.Struct, p)
|
||||
return err
|
||||
}
|
||||
|
||||
func UnmarshalUnixPath(s tunnelrpc.UnixPath) (*UnixPath, error) {
|
||||
p := new(UnixPath)
|
||||
err := pogs.Extract(p, tunnelrpc.UnixPath_TypeID, s.Struct)
|
||||
err := pogs.Extract(p, tunnelrpc.HTTPOriginConfig_TypeID, s.Struct)
|
||||
return p, err
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user