diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 27911960..049dc25a 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -395,8 +395,8 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan uploadManager.Start() } - sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag) - server, err := sshserver.New(logManager, logger, version, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) + localServerAddress := "127.0.0.1:" + c.String(sshPortFlag) + server, err := sshserver.New(logManager, logger, version, localServerAddress, c.String("hostname"), shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) if err != nil { msg := "Cannot create new SSH Server" logger.WithError(err).Error(msg) @@ -411,7 +411,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan // TODO: remove when declarative tunnels are implemented. close(shutdownC) }() - c.Set("url", "ssh://"+sshServerAddress) + c.Set("url", "ssh://"+localServerAddress) } if host := hostnameFromURI(c.String("url")); host != "" { diff --git a/sshgen/sshgen.go b/sshgen/sshgen.go index 12538c84..2d6f326b 100644 --- a/sshgen/sshgen.go +++ b/sshgen/sshgen.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" - "errors" "fmt" "io" "io/ioutil" @@ -20,6 +19,7 @@ import ( cfpath "github.com/cloudflare/cloudflared/cmd/cloudflared/path" "github.com/coreos/go-oidc/jose" homedir "github.com/mitchellh/go-homedir" + "github.com/pkg/errors" gossh "golang.org/x/crypto/ssh" ) @@ -73,48 +73,54 @@ func GenerateShortLivedCertificate(appURL *url.URL, token string) error { // handleCertificateGeneration takes a JWT and uses it build a signPayload // to send to the Sign endpoint with the public key from the keypair it generated func handleCertificateGeneration(token, fullName string) (string, error) { + pub, err := generateKeyPair(fullName) + if err != nil { + return "", err + } + + return SignCert(token, string(pub)) +} + +func SignCert(token, pubKey string) (string, error) { if token == "" { return "", errors.New("invalid token") } jwt, err := jose.ParseJWT(token) if err != nil { - return "", err + return "", errors.Wrap(err, "failed to parse JWT") } claims, err := jwt.Claims() if err != nil { - return "", err + return "", errors.Wrap(err, "failed to retrieve JWT claims") } issuer, _, err := claims.StringClaim("iss") if err != nil { - return "", err - } - - pub, err := generateKeyPair(fullName) - if err != nil { - return "", err + return "", errors.Wrap(err, "failed to retrieve JWT iss") } buf, err := json.Marshal(&signPayload{ - PublicKey: string(pub), + PublicKey: pubKey, JWT: token, Issuer: issuer, }) if err != nil { - return "", err + return "", errors.Wrap(err, "failed to marshal signPayload") } - var res *http.Response if mockRequest != nil { res, err = mockRequest(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf)) } else { - res, err = http.Post(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf)) + client := http.Client{ + Timeout: 10 * time.Second, + } + res, err = client.Post(issuer+signEndpoint, "application/json", bytes.NewBuffer(buf)) } if err != nil { - return "", err + return "", errors.Wrap(err, "failed to send request") } defer res.Body.Close() @@ -130,9 +136,9 @@ func handleCertificateGeneration(token, fullName string) (string, error) { var signRes signResponse if err := decoder.Decode(&signRes); err != nil { - return "", err + return "", errors.Wrap(err, "failed to decode HTTP response") } - return signRes.Certificate, err + return signRes.Certificate, nil } // generateKeyPair creates a EC keypair (P256) and stores them in the homedir. diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index f8e1adba..eeb05f23 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -3,6 +3,9 @@ package sshserver import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "encoding/binary" "encoding/json" "fmt" @@ -13,6 +16,7 @@ import ( "strings" "time" + "github.com/cloudflare/cloudflared/sshgen" "github.com/cloudflare/cloudflared/sshlog" "github.com/gliderlabs/ssh" "github.com/google/uuid" @@ -30,8 +34,9 @@ const ( auditEventShell = "shell" sshContextSessionID = "sessionID" sshContextEventLogger = "eventLogger" - sshContextDestination = "sshDest" - sshPreambleLength = 4 + sshContextPreamble = "sshPreamble" + sshContextSSHClient = "sshClient" + SSHPreambleLength = 4 ) type auditEvent struct { @@ -41,31 +46,53 @@ type auditEvent struct { User string `json:"user,omitempty"` Login string `json:"login,omitempty"` Datetime string `json:"datetime,omitempty"` + Hostname string `json:"hostname,omitempty"` Destination string `json:"destination,omitempty"` } +// sshConn wraps the incoming net.Conn and a cleanup function +// This is done to allow the outgoing SSH client to be retrieved and closed when the conn itself is closed. +type sshConn struct { + net.Conn + cleanupFunc func() +} + +// close calls the cleanupFunc before closing the conn +func (c sshConn) Close() error { + c.cleanupFunc() + return c.Conn.Close() +} + type SSHProxy struct { ssh.Server + hostname string logger *logrus.Logger shutdownC chan struct{} caCert ssh.PublicKey logManager sshlog.Manager } +type SSHPreamble struct { + Destination string + JWT string +} + // New creates a new SSHProxy and configures its host keys and authentication by the data provided -func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHProxy, error) { +func New(logManager sshlog.Manager, logger *logrus.Logger, version, localAddress, hostname string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHProxy, error) { sshProxy := SSHProxy{ + hostname: hostname, logger: logger, shutdownC: shutdownC, logManager: logManager, } sshProxy.Server = ssh.Server{ - Addr: address, - MaxTimeout: maxTimeout, - IdleTimeout: idleTimeout, - Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS), - ConnCallback: sshProxy.connCallback, + Addr: localAddress, + MaxTimeout: maxTimeout, + IdleTimeout: idleTimeout, + Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS), + PublicKeyHandler: sshProxy.proxyAuthCallback, + ConnCallback: sshProxy.connCallback, ChannelHandlers: map[string]ssh.ChannelHandler{ "default": sshProxy.channelHandler, }, @@ -92,23 +119,54 @@ func (s *SSHProxy) Start() error { return s.ListenAndServe() } +// proxyAuthCallback attempts to connect to ultimate SSH destination. If successful, it allows the incoming connection +// to connect to the proxy and saves the outgoing SSH client to the context. Otherwise, no connection to the +// the proxy is allowed. +func (s *SSHProxy) proxyAuthCallback(ctx ssh.Context, key ssh.PublicKey) bool { + client, err := s.dialDestination(ctx) + if err != nil { + return false + } + ctx.SetValue(sshContextSSHClient, client) + return true +} + +// connCallback reads the preamble sent from the proxy server and saves an audit event logger to the context. +// If any errors occur, the connection is terminated by returning nil from the callback. func (s *SSHProxy) connCallback(ctx ssh.Context, conn net.Conn) net.Conn { // AUTH-2050: This is a temporary workaround of a timing issue in the tunnel muxer to allow further testing. // TODO: Remove this time.Sleep(10 * time.Millisecond) - if err := s.configureSSHDestination(conn, ctx); err != nil { - if err != io.EOF { - s.logger.WithError(err).Error("failed to read SSH destination") + preamble, err := s.readPreamble(conn) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + s.logger.Warn("Could not establish session. Client likely does not have --destination set and is using old-style ssh config") + } else if err != io.EOF { + s.logger.WithError(err).Error("failed to read SSH preamble") } return nil } + ctx.SetValue(sshContextPreamble, preamble) - if err := s.configureLogger(ctx); err != nil { + logger, sessionID, err := s.auditLogger() + if err != nil { s.logger.WithError(err).Error("failed to configure logger") return nil } - return conn + ctx.SetValue(sshContextEventLogger, logger) + ctx.SetValue(sshContextSessionID, sessionID) + + // attempts to retrieve and close the outgoing ssh client when the incoming conn is closed. + // If no client exists, the conn is being closed before the PublicKeyCallback was called (where the client is created). + cleanupFunc := func() { + client, ok := ctx.Value(sshContextSSHClient).(*gossh.Client) + if ok && client != nil { + client.Close() + } + } + + return sshConn{conn, cleanupFunc} } // channelHandler proxies incoming and outgoing SSH traffic back and forth over an SSH Channel @@ -129,13 +187,12 @@ func (s *SSHProxy) channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newCh } defer localChan.Close() - // AUTH-2136 TODO: multiplex ssh client between channels - client, err := s.createSSHClient(ctx) - if err != nil { - s.logger.WithError(err).Error("Failed to dial remote server") + // client will be closed when the sshConn is closed + client, ok := ctx.Value(sshContextSSHClient).(*gossh.Client) + if !ok { + s.logger.Error("Could not retrieve client from context") return } - defer client.Close() remoteChan, remoteChanReqs, err := client.OpenChannel(newChan.ChannelType(), newChan.ExtraData()) if err != nil { @@ -196,54 +253,116 @@ func (s *SSHProxy) proxyChannel(localChan, remoteChan gossh.Channel, localChanRe } } -// configureSSHDestination reads a preamble from the SSH connection before any SSH traffic is sent. -// This preamble contains the ultimate SSH destination the proxy will connect too. -// The first 4 bytes contain the length of the destination which follows immediately. -func (s *SSHProxy) configureSSHDestination(conn net.Conn, ctx ssh.Context) error { - size := make([]byte, sshPreambleLength) +// readPreamble reads a preamble from the SSH connection before any SSH traffic is sent. +// This preamble is a JSON encoded struct containing the users JWT and ultimate destination. +// The first 4 bytes contain the length of the preamble which follows immediately. +func (s *SSHProxy) readPreamble(conn net.Conn) (*SSHPreamble, error) { + // Set conn read deadline while reading preamble to prevent hangs if preamble wasnt sent. + if err := conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil { + return nil, errors.Wrap(err, "failed to set conn deadline") + } + defer func() { + if err := conn.SetReadDeadline(time.Time{}); err != nil { + s.logger.WithError(err).Error("Failed to unset conn read deadline") + } + }() + + size := make([]byte, SSHPreambleLength) if _, err := io.ReadFull(conn, size); err != nil { - return err + return nil, err } payloadLength := binary.BigEndian.Uint32(size) - data := make([]byte, payloadLength) - if _, err := io.ReadFull(conn, data); err != nil { - return err + payload := make([]byte, payloadLength) + if _, err := io.ReadFull(conn, payload); err != nil { + return nil, err } - destAddr := string(data) - destUrl, err := url.Parse(destAddr) + var preamble SSHPreamble + err := json.Unmarshal(payload, &preamble) if err != nil { - return errors.Wrap(err, "failed to parse URL") + return nil, err + } + + destUrl, err := url.Parse(preamble.Destination) + if err != nil { + return nil, errors.Wrap(err, "failed to parse URL") } if destUrl.Port() == "" { - destAddr += ":22" + preamble.Destination += ":22" } - ctx.SetValue(sshContextDestination, destAddr) - return nil + return &preamble, nil } -// createSSHClient creates a new SSH client and dials the destination server -func (s *SSHProxy) createSSHClient(ctx ssh.Context) (*gossh.Client, error) { +// dialDestination creates a new SSH client and dials the destination server +func (s *SSHProxy) dialDestination(ctx ssh.Context) (*gossh.Client, error) { + preamble, ok := ctx.Value(sshContextPreamble).(*SSHPreamble) + if !ok { + msg := "failed to retrieve SSH preamble from context" + s.logger.Error(msg) + return nil, errors.New(msg) + } + + signer, err := s.genSSHSigner(preamble.JWT) + if err != nil { + s.logger.WithError(err).Error("Failed to generate signed short lived cert") + return nil, err + } + clientConfig := &gossh.ClientConfig{ User: ctx.User(), // AUTH-2103 TODO: proper host key check HostKeyCallback: gossh.InsecureIgnoreHostKey(), - // AUTH-2114 TODO: replace with short lived cert auth - Auth: []gossh.AuthMethod{gossh.Password("test")}, - ClientVersion: ctx.ServerVersion(), + Auth: []gossh.AuthMethod{gossh.PublicKeys(signer)}, + ClientVersion: ctx.ServerVersion(), } - address, ok := ctx.Value(sshContextDestination).(string) - if !ok { - return nil, errors.New("failed to retrieve SSH destination from context") - } - client, err := gossh.Dial("tcp", address, clientConfig) + client, err := gossh.Dial("tcp", preamble.Destination, clientConfig) if err != nil { + s.logger.WithError(err).Info("Failed to connect to destination SSH server") return nil, err } return client, nil } +// Generates a key pair and sends public key to get signed by CA +func (s *SSHProxy) genSSHSigner(jwt string) (gossh.Signer, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, errors.Wrap(err, "failed to generate ecdsa key pair") + } + + pub, err := gossh.NewPublicKey(&key.PublicKey) + if err != nil { + return nil, errors.Wrap(err, "failed to convert ecdsa public key to SSH public key") + } + + pubBytes := gossh.MarshalAuthorizedKey(pub) + signedCertBytes, err := sshgen.SignCert(jwt, string(pubBytes)) + if err != nil { + return nil, errors.Wrap(err, "failed to retrieve cert from SSHCAAPI") + } + + signedPub, _, _, _, err := gossh.ParseAuthorizedKey([]byte(signedCertBytes)) + if err != nil { + return nil, errors.Wrap(err, "failed to parse SSH public key") + } + + cert, ok := signedPub.(*gossh.Certificate) + if !ok { + return nil, errors.Wrap(err, "failed to assert public key as certificate") + } + signer, err := gossh.NewSignerFromKey(key) + if err != nil { + return nil, errors.Wrap(err, "failed to create signer") + } + + certSigner, err := gossh.NewCertSigner(cert, signer) + if err != nil { + return nil, errors.Wrap(err, "failed to create cert signer") + } + return certSigner, nil +} + // forwardChannelRequest sends request req to SSH channel sshChan, waits for reply, and sends the reply back. func (s *SSHProxy) forwardChannelRequest(sshChan gossh.Channel, req *gossh.Request) error { reply, err := sshChan.SendRequest(req.Type, req.WantReply, req.Payload) @@ -282,20 +401,18 @@ func (s *SSHProxy) logChannelRequest(req *gossh.Request, conn *gossh.ServerConn, s.logAuditEvent(conn, event, eventType, ctx) } -func (s *SSHProxy) configureLogger(ctx ssh.Context) error { +func (s *SSHProxy) auditLogger() (io.WriteCloser, string, error) { sessionUUID, err := uuid.NewRandom() if err != nil { - return errors.Wrap(err, "failed to create sessionID") + return nil, "", errors.Wrap(err, "failed to create sessionID") } sessionID := sessionUUID.String() writer, err := s.logManager.NewLogger(fmt.Sprintf("%s-event.log", sessionID), s.logger) if err != nil { - return errors.Wrap(err, "failed to create logger") + return nil, "", errors.Wrap(err, "failed to create logger") } - ctx.SetValue(sshContextEventLogger, writer) - ctx.SetValue(sshContextSessionID, sessionID) - return nil + return writer, sessionID, nil } func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string, ctx ssh.Context) { @@ -306,9 +423,12 @@ func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string return } - destination, destOk := ctx.Value(sshContextDestination).(string) - if !destOk { - s.logger.Error("Failed to retrieve SSH destination from context") + var destination string + preamble, ok := ctx.Value(sshContextPreamble).(*SSHPreamble) + if ok { + destination = preamble.Destination + } else { + s.logger.Error("Failed to retrieve SSH preamble from context") } ae := auditEvent{ @@ -318,6 +438,7 @@ func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string User: conn.User(), Login: conn.User(), Datetime: time.Now().UTC().Format(time.RFC3339), + Hostname: s.hostname, Destination: destination, } data, err := json.Marshal(&ae) diff --git a/sshserver/sshserver_windows.go b/sshserver/sshserver_windows.go index d5f89744..085d6e08 100644 --- a/sshserver/sshserver_windows.go +++ b/sshserver/sshserver_windows.go @@ -13,7 +13,12 @@ import ( type SSHServer struct{} -func New(_ sshlog.Manager, _ *logrus.Logger, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) { +type SSHPreamble struct { + Destination string + JWT string +} + +func New(_ sshlog.Manager, _ *logrus.Logger, _, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) { return nil, errors.New("cloudflared ssh server is not supported on windows") } diff --git a/websocket/websocket.go b/websocket/websocket.go index 26436610..4ada6190 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -6,12 +6,14 @@ import ( "crypto/tls" "encoding/base64" "encoding/binary" + "encoding/json" "errors" "io" "net" "net/http" "time" + "github.com/cloudflare/cloudflared/sshserver" "github.com/gorilla/websocket" "github.com/sirupsen/logrus" ) @@ -155,9 +157,11 @@ func StartProxyServer(logger *logrus.Logger, listener net.Listener, remote strin conn.Close() }() + token := r.Header.Get("cf-access-token") if destination := r.Header.Get("CF-Access-SSH-Destination"); destination != "" { - if err := sendSSHDestination(stream, destination); err != nil { - logger.WithError(err).Error("Failed to send SSH destination") + if err := sendSSHPreamble(stream, destination, token); err != nil { + logger.WithError(err).Error("Failed to send SSH preamble") + return } } @@ -167,16 +171,22 @@ func StartProxyServer(logger *logrus.Logger, listener net.Listener, remote strin return httpServer.Serve(listener) } -// sendSSHDestination sends the final SSH destination address to the cloudflared SSH proxy +// sendSSHPreamble sends the final SSH destination address to the cloudflared SSH proxy // The destination is preceded by its length -func sendSSHDestination(stream net.Conn, destination string) error { - sizeBytes := make([]byte, 4) - binary.BigEndian.PutUint32(sizeBytes, uint32(len(destination))) +func sendSSHPreamble(stream net.Conn, destination, token string) error { + preamble := &sshserver.SSHPreamble{Destination: destination, JWT: token} + payload, err := json.Marshal(preamble) + if err != nil { + return err + } + + sizeBytes := make([]byte, sshserver.SSHPreambleLength) + binary.BigEndian.PutUint32(sizeBytes, uint32(len(payload))) if _, err := stream.Write(sizeBytes); err != nil { return err } - if _, err := stream.Write([]byte(destination)); err != nil { + if _, err := stream.Write(payload); err != nil { return err } return nil