AUTH-2018: Adds support for authorized keys and short lived certs

This commit is contained in:
Michael Borkenstein
2019-08-22 11:36:21 -05:00
parent df25ed9bde
commit baec3e289e
16 changed files with 549 additions and 33 deletions

View File

@@ -4,8 +4,8 @@ package sshserver
import (
"bufio"
"errors"
"fmt"
"github.com/pkg/errors"
"io"
"os"
"os/exec"
@@ -19,18 +19,15 @@ import (
"github.com/sirupsen/logrus"
)
const (
defaultShellPrompt = `\e[0;31m[\u@\h \W]\$ \e[m `
configDir = "/etc/cloudflared/"
)
type SSHServer struct {
ssh.Server
logger *logrus.Logger
shutdownC chan struct{}
logger *logrus.Logger
shutdownC chan struct{}
caCert ssh.PublicKey
getUserFunc func(string) (*User, error)
}
func New(logger *logrus.Logger, address string, shutdownC chan struct{}) (*SSHServer, error) {
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool) (*SSHServer, error) {
currentUser, err := user.Current()
if err != nil {
return nil, err
@@ -39,11 +36,28 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}) (*SSHSe
return nil, errors.New("cloudflared SSH server needs to run as root")
}
sshServer := SSHServer{ssh.Server{Addr: address}, logger, shutdownC}
sshServer := SSHServer{
Server: ssh.Server{Addr: address},
logger: logger,
shutdownC: shutdownC,
getUserFunc: lookupUser,
}
if err := sshServer.configureHostKeys(); err != nil {
return nil, err
}
if shortLivedCertAuth {
caCert, err := getCACert()
if err != nil {
return nil, err
}
sshServer.caCert = caCert
sshServer.PublicKeyHandler = sshServer.shortLivedCertHandler
} else {
sshServer.PublicKeyHandler = sshServer.authorizedKeyHandler
}
return &sshServer, nil
}
@@ -64,11 +78,9 @@ func (s *SSHServer) Start() error {
func (s *SSHServer) connectionHandler(session ssh.Session) {
// Get uid and gid of user attempting to login
sshUser, err := lookupUser(session.User())
if err != nil {
if _, err := io.WriteString(session, "Invalid credentials\n"); err != nil {
s.logger.WithError(err).Error("Invalid credentials: Failed to write to SSH session")
}
sshUser, ok := session.Context().Value("sshUser").(*User)
if !ok || sshUser == nil {
s.logger.Error("Error retrieving credentials from session")
s.CloseSession(session)
return
}
@@ -86,17 +98,22 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
return
}
uidInt, uidErr := stringToUint32(sshUser.Uid)
gidInt, gidErr := stringToUint32(sshUser.Gid)
if uidErr != nil || gidErr != nil {
uidInt, err := stringToUint32(sshUser.Uid)
if err != nil {
s.logger.WithError(err).Error("Invalid user")
s.CloseSession(session)
return
}
gidInt, err := stringToUint32(sshUser.Gid)
if err != nil {
s.logger.WithError(err).Error("Invalid user group")
s.CloseSession(session)
return
}
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}}
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Name))
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir))
cmd.Dir = sshUser.HomeDir
psuedoTTY, err := pty.Start(cmd)