AUTH-2089: Revise ssh server to function as a proxy

This commit is contained in:
Michael Borkenstein
2019-09-30 15:44:23 -05:00
parent b3bcce97da
commit dbde3870da
10 changed files with 205 additions and 860 deletions

View File

@@ -4,36 +4,29 @@ package sshserver
import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"unsafe"
"github.com/cloudflare/cloudflared/sshlog"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
gossh "golang.org/x/crypto/ssh"
)
const (
auditEventAuth = "auth"
auditEventStart = "session_start"
auditEventStop = "session_stop"
auditEventExec = "exec"
auditEventScp = "scp"
auditEventResize = "resize"
sshContextSessionID = "sessionID"
auditEventStart = "session_start"
auditEventStop = "session_stop"
auditEventExec = "exec"
auditEventScp = "scp"
auditEventResize = "resize"
auditEventShell = "shell"
sshContextSessionID = "sessionID"
sshContextEventLogger = "eventLogger"
)
@@ -47,8 +40,7 @@ type auditEvent struct {
IPAddress string `json:"ip_address,omitempty"`
}
// SSHServer adds on to the ssh.Server of the gliderlabs package
type SSHServer struct {
type SSHProxy struct {
ssh.Server
logger *logrus.Logger
shutdownC chan struct{}
@@ -56,62 +48,40 @@ type SSHServer struct {
logManager sshlog.Manager
}
// New creates a new SSHServer and configures its host keys and authenication by the data provided
func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration, enablePortForwarding bool) (*SSHServer, error) {
currentUser, err := user.Current()
if err != nil {
return nil, err
}
if currentUser.Uid != "0" {
return nil, errors.New("cloudflared SSH server needs to run as root")
}
forwardHandler := &ssh.ForwardedTCPHandler{}
sshServer := SSHServer{
Server: ssh.Server{
Addr: address,
MaxTimeout: maxTimeout,
IdleTimeout: idleTimeout,
Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS),
// Register SSH global Request handlers to respond to tcpip forwarding
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": forwardHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
},
// Register SSH channel types
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": ssh.DirectTCPIPHandler,
},
},
// New creates a new SSHProxy and configures its host keys and authentication by the data provided
func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHProxy, error) {
sshProxy := SSHProxy{
logger: logger,
shutdownC: shutdownC,
logManager: logManager,
}
sshProxy.Server = ssh.Server{
Addr: address,
MaxTimeout: maxTimeout,
IdleTimeout: idleTimeout,
Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS),
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": sshProxy.channelHandler,
},
}
// AUTH-2050: This is a temporary workaround of a timing issue in the tunnel muxer to allow further testing.
// TODO: Remove this
sshServer.ConnCallback = func(conn net.Conn) net.Conn {
sshProxy.ConnCallback = func(conn net.Conn) net.Conn {
time.Sleep(10 * time.Millisecond)
return conn
}
if enablePortForwarding {
sshServer.LocalPortForwardingCallback = allowForward
sshServer.ReversePortForwardingCallback = allowForward
}
if err := sshServer.configureHostKeys(); err != nil {
if err := sshProxy.configureHostKeys(); err != nil {
return nil, err
}
sshServer.configureAuthentication()
return &sshServer, nil
return &sshProxy, nil
}
// Start the SSH server listener to start handling SSH connections from clients
func (s *SSHServer) Start() error {
// Start the SSH proxy listener to start handling SSH connections from clients
func (s *SSHProxy) Start() error {
s.logger.Infof("Starting SSH server at %s", s.Addr)
go func() {
@@ -121,256 +91,180 @@ func (s *SSHServer) Start() error {
}
}()
s.Handle(s.connectionHandler)
return s.ListenAndServe()
}
func (s *SSHServer) connectionHandler(session ssh.Session) {
sessionUUID, err := uuid.NewRandom()
if err != nil {
if _, err := io.WriteString(session, "Failed to generate session ID\n"); err != nil {
s.logger.WithError(err).Error("Failed to generate session ID: Failed to write to SSH session")
}
s.errorAndExit(session, "", nil)
// channelHandler proxies incoming and outgoing SSH traffic back and forth over an SSH Channel
func (s *SSHProxy) channelHandler(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
if err := s.configureAuditLogger(ctx); err != nil {
s.logger.WithError(err).Error("Failed to configure audit logging")
return
}
clientConfig := &gossh.ClientConfig{
User: conn.User(),
// AUTH-2103 TODO: proper host key check
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
// AUTH-2114 TODO: replace with short lived cert auth
Auth: []gossh.AuthMethod{gossh.Password("test")},
ClientVersion: s.Version,
}
switch newChan.ChannelType() {
case "session":
// Accept incoming channel request from client
localChan, localChanReqs, err := newChan.Accept()
if err != nil {
s.logger.WithError(err).Error("Failed to accept session channel")
return
}
defer localChan.Close()
// AUTH-2088 TODO: retrieve ssh target from tunnel
// Create outgoing ssh connection to destination SSH server
client, err := gossh.Dial("tcp", "localhost:22", clientConfig)
if err != nil {
s.logger.WithError(err).Error("Failed to dial remote server")
return
}
defer client.Close()
// Open channel session channel to destination server
remoteChan, remoteChanReqs, err := client.OpenChannel("session", []byte{})
if err != nil {
s.logger.WithError(err).Error("Failed to open remote channel")
return
}
defer remoteChan.Close()
// Proxy ssh traffic back and forth between client and destination
s.proxyChannel(localChan, remoteChan, localChanReqs, remoteChanReqs, conn, ctx)
}
}
// proxyChannel couples two SSH channels and proxies SSH traffic and channel requests back and forth.
func (s *SSHProxy) proxyChannel(localChan, remoteChan gossh.Channel, localChanReqs, remoteChanReqs <-chan *gossh.Request, conn *gossh.ServerConn, ctx ssh.Context) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(localChan, remoteChan); err != nil {
s.logger.WithError(err).Error("remote to local copy error")
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(remoteChan, localChan); err != nil {
s.logger.WithError(err).Error("local to remote copy error")
}
done <- struct{}{}
}()
s.logAuditEvent(conn, "", auditEventStart, ctx)
defer s.logAuditEvent(conn, "", auditEventStop, ctx)
// Proxy channel requests
for {
select {
case req := <-localChanReqs:
if req == nil {
return
}
if err := s.forwardChannelRequest(remoteChan, req); err != nil {
s.logger.WithError(err).Error("Failed to forward request")
return
}
s.logChannelRequest(req, conn, ctx)
case req := <-remoteChanReqs:
if req == nil {
return
}
if err := s.forwardChannelRequest(localChan, req); err != nil {
s.logger.WithError(err).Error("Failed to forward request")
return
}
case <-done:
return
}
}
}
// forwardChannelRequest sends request req to SSH channel sshChan, waits for reply, and sends the reply back.
func (s *SSHProxy) forwardChannelRequest(sshChan gossh.Channel, req *gossh.Request) error {
reply, err := sshChan.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
return errors.Wrap(err, "Failed to send request")
}
if err := req.Reply(reply, nil); err != nil {
return errors.Wrap(err, "Failed to reply to request")
}
return nil
}
// logChannelRequest creates an audit log for different types of channel requests
func (s *SSHProxy) logChannelRequest(req *gossh.Request, conn *gossh.ServerConn, ctx ssh.Context) {
var eventType string
var event string
switch req.Type {
case "exec":
var payload struct{ Value string }
if err := gossh.Unmarshal(req.Payload, &payload); err != nil {
s.logger.WithError(err).Errorf("Failed to unmarshal channel request payload: %s:%s", req.Type, req.Payload)
}
event = payload.Value
eventType = auditEventExec
if strings.HasPrefix(string(req.Payload), "scp") {
eventType = auditEventScp
}
case "shell":
eventType = auditEventShell
case "window-change":
eventType = auditEventResize
}
s.logAuditEvent(conn, event, eventType, ctx)
}
func (s *SSHProxy) configureAuditLogger(ctx ssh.Context) error {
sessionUUID, err := uuid.NewRandom()
if err != nil {
return errors.New("failed to generate session ID")
}
sessionID := sessionUUID.String()
eventLogger, err := s.logManager.NewLogger(fmt.Sprintf("%s-event.log", sessionID), s.logger)
if err != nil {
if _, err := io.WriteString(session, "Failed to create event log\n"); err != nil {
s.logger.WithError(err).Error("Failed to create event log: Failed to write to create event logger")
}
s.errorAndExit(session, "", nil)
return
return errors.New("failed to create event log")
}
sshContext, ok := session.Context().(ssh.Context)
if !ok {
s.logger.Error("Could not retrieve session context")
s.errorAndExit(session, "", nil)
}
sshContext.SetValue(sshContextSessionID, sessionID)
sshContext.SetValue(sshContextEventLogger, eventLogger)
// Get uid and gid of user attempting to login
sshUser, uidInt, gidInt, success := s.getSSHUser(session, eventLogger)
if !success {
return
}
// Spawn shell under user
var cmd *exec.Cmd
if session.RawCommand() != "" {
cmd = exec.Command(sshUser.Shell, "-c", session.RawCommand())
event := auditEventExec
if strings.HasPrefix(session.RawCommand(), "scp") {
event = auditEventScp
}
s.logAuditEvent(session, event)
} else {
cmd = exec.Command(sshUser.Shell)
s.logAuditEvent(session, auditEventStart)
defer s.logAuditEvent(session, auditEventStop)
}
// Supplementary groups are not explicitly specified. They seem to be inherited by default.
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}, Setsid: true}
cmd.Env = append(cmd.Env, session.Environ()...)
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir))
cmd.Dir = sshUser.HomeDir
var shellInput io.WriteCloser
var shellOutput io.ReadCloser
pr, pw := io.Pipe()
defer pw.Close()
ptyReq, winCh, isPty := session.Pty()
if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
tty, err := s.startPtySession(cmd, winCh, func() {
s.logAuditEvent(session, auditEventResize)
})
shellInput = tty
shellOutput = tty
if err != nil {
s.logger.WithError(err).Error("Failed to start pty session")
close(s.shutdownC)
return
}
} else {
var shellError io.ReadCloser
shellInput, shellOutput, shellError, err = s.startNonPtySession(cmd)
if err != nil {
s.logger.WithError(err).Error("Failed to start non-pty session")
close(s.shutdownC)
return
}
// Write stderr to both the command recorder, and remote user
go func() {
mw := io.MultiWriter(pw, session.Stderr())
if _, err := io.Copy(mw, shellError); err != nil {
s.logger.WithError(err).Error("Failed to write stderr to user")
}
}()
}
sessionLogger, err := s.logManager.NewSessionLogger(fmt.Sprintf("%s-session.log", sessionID), s.logger)
if err != nil {
if _, err := io.WriteString(session, "Failed to create log\n"); err != nil {
s.logger.WithError(err).Error("Failed to create log: Failed to write to SSH session")
}
s.errorAndExit(session, "", nil)
return
}
go func() {
defer sessionLogger.Close()
defer pr.Close()
_, err := io.Copy(sessionLogger, pr)
if err != nil {
s.logger.WithError(err).Error("Failed to write session log")
}
}()
// Write stdin to shell
go func() {
/*
Only close shell stdin for non-pty sessions because they have distinct stdin, stdout, and stderr.
This is done to prevent commands like SCP from hanging after all data has been sent.
PTY sessions share one file for all three streams and the shell process closes it.
Closing it here also closes shellOutput and causes an error on copy().
*/
if !isPty {
defer shellInput.Close()
}
if _, err := io.Copy(shellInput, session); err != nil {
s.logger.WithError(err).Error("Failed to write incoming command to pty")
}
}()
// Write stdout to both the command recorder, and remote user
mw := io.MultiWriter(pw, session)
if _, err := io.Copy(mw, shellOutput); err != nil {
s.logger.WithError(err).Error("Failed to write stdout to user")
}
// Wait for all resources associated with cmd to be released
// Returns error if shell exited with a non-zero status or received a signal
if err := cmd.Wait(); err != nil {
s.logger.WithError(err).Debug("Shell did not close correctly")
}
ctx.SetValue(sshContextSessionID, sessionID)
ctx.SetValue(sshContextEventLogger, eventLogger)
return nil
}
// getSSHUser gets the ssh user, uid, and gid of the user attempting to login
func (s *SSHServer) getSSHUser(session ssh.Session, eventLogger io.WriteCloser) (*User, uint32, uint32, bool) {
// Get uid and gid of user attempting to login
sshUser, ok := session.Context().Value("sshUser").(*User)
if !ok || sshUser == nil {
s.errorAndExit(session, "Error retrieving credentials from session", nil)
return nil, 0, 0, false
}
s.logAuditEvent(session, auditEventAuth)
uidInt, err := stringToUint32(sshUser.Uid)
if err != nil {
s.errorAndExit(session, "Invalid user", err)
return sshUser, 0, 0, false
}
gidInt, err := stringToUint32(sshUser.Gid)
if err != nil {
s.errorAndExit(session, "Invalid user group", err)
return sshUser, 0, 0, false
}
return sshUser, uidInt, gidInt, true
}
// errorAndExit reports an error with the session and exits
func (s *SSHServer) errorAndExit(session ssh.Session, errText string, err error) {
if exitError := session.Exit(1); exitError != nil {
s.logger.WithError(exitError).Error("Failed to close SSH session")
} else if err != nil {
s.logger.WithError(err).Error(errText)
} else if errText != "" {
s.logger.Error(errText)
}
}
func (s *SSHServer) startNonPtySession(cmd *exec.Cmd) (stdin io.WriteCloser, stdout io.ReadCloser, stderr io.ReadCloser, err error) {
stdin, err = cmd.StdinPipe()
if err != nil {
return
}
stdout, err = cmd.StdoutPipe()
if err != nil {
return
}
stderr, err = cmd.StderrPipe()
if err != nil {
return
}
if err = cmd.Start(); err != nil {
return
}
return
}
func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logCallback func()) (io.ReadWriteCloser, error) {
tty, err := pty.Start(cmd)
if err != nil {
return nil, err
}
// Handle terminal window size changes
go func() {
for win := range winCh {
if errNo := setWinsize(tty, win.Width, win.Height); errNo != 0 {
s.logger.WithError(err).Error("Failed to set pty window size")
close(s.shutdownC)
return
}
logCallback()
}
}()
return tty, nil
}
func (s *SSHServer) logAuditEvent(session ssh.Session, eventType string) {
username := "unknown"
sshUser, ok := session.Context().Value("sshUser").(*User)
if ok && sshUser != nil {
username = sshUser.Username
}
sessionID, ok := session.Context().Value(sshContextSessionID).(string)
func (s *SSHProxy) logAuditEvent(conn *gossh.ServerConn, event, eventType string, ctx ssh.Context) {
sessionID, ok := ctx.Value(sshContextSessionID).(string)
if !ok {
s.logger.Error("Failed to retrieve sessionID from context")
return
}
writer, ok := session.Context().Value(sshContextEventLogger).(io.WriteCloser)
writer, ok := ctx.Value(sshContextEventLogger).(io.WriteCloser)
if !ok {
s.logger.Error("Failed to retrieve eventLogger from context")
return
}
event := auditEvent{
Event: session.RawCommand(),
ae := auditEvent{
Event: event,
EventType: eventType,
SessionID: sessionID,
User: username,
Login: username,
User: conn.User(),
Login: conn.User(),
Datetime: time.Now().UTC().Format(time.RFC3339),
IPAddress: session.RemoteAddr().String(),
IPAddress: conn.RemoteAddr().String(),
}
data, err := json.Marshal(&event)
data, err := json.Marshal(&ae)
if err != nil {
s.logger.WithError(err).Error("Failed to marshal audit event. malformed audit object")
return
@@ -381,19 +275,3 @@ func (s *SSHServer) logAuditEvent(session ssh.Session, eventType string) {
}
}
// Sets PTY window size for terminal
func setWinsize(f *os.File, w, h int) syscall.Errno {
_, _, errNo := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
return errNo
}
func stringToUint32(str string) (uint32, error) {
uid, err := strconv.ParseUint(str, 10, 32)
return uint32(uid), err
}
func allowForward(_ ssh.Context, _ string, _ uint32) bool {
return true
}