mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 18:19:57 +00:00
AUTH-2018: Adds support for authorized keys and short lived certs
This commit is contained in:
@@ -4,8 +4,8 @@ package sshserver
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -19,18 +19,15 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultShellPrompt = `\e[0;31m[\u@\h \W]\$ \e[m `
|
||||
configDir = "/etc/cloudflared/"
|
||||
)
|
||||
|
||||
type SSHServer struct {
|
||||
ssh.Server
|
||||
logger *logrus.Logger
|
||||
shutdownC chan struct{}
|
||||
logger *logrus.Logger
|
||||
shutdownC chan struct{}
|
||||
caCert ssh.PublicKey
|
||||
getUserFunc func(string) (*User, error)
|
||||
}
|
||||
|
||||
func New(logger *logrus.Logger, address string, shutdownC chan struct{}) (*SSHServer, error) {
|
||||
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool) (*SSHServer, error) {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -39,11 +36,28 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}) (*SSHSe
|
||||
return nil, errors.New("cloudflared SSH server needs to run as root")
|
||||
}
|
||||
|
||||
sshServer := SSHServer{ssh.Server{Addr: address}, logger, shutdownC}
|
||||
sshServer := SSHServer{
|
||||
Server: ssh.Server{Addr: address},
|
||||
logger: logger,
|
||||
shutdownC: shutdownC,
|
||||
getUserFunc: lookupUser,
|
||||
}
|
||||
|
||||
if err := sshServer.configureHostKeys(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if shortLivedCertAuth {
|
||||
caCert, err := getCACert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sshServer.caCert = caCert
|
||||
sshServer.PublicKeyHandler = sshServer.shortLivedCertHandler
|
||||
} else {
|
||||
sshServer.PublicKeyHandler = sshServer.authorizedKeyHandler
|
||||
}
|
||||
|
||||
return &sshServer, nil
|
||||
}
|
||||
|
||||
@@ -64,11 +78,9 @@ func (s *SSHServer) Start() error {
|
||||
func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||
|
||||
// Get uid and gid of user attempting to login
|
||||
sshUser, err := lookupUser(session.User())
|
||||
if err != nil {
|
||||
if _, err := io.WriteString(session, "Invalid credentials\n"); err != nil {
|
||||
s.logger.WithError(err).Error("Invalid credentials: Failed to write to SSH session")
|
||||
}
|
||||
sshUser, ok := session.Context().Value("sshUser").(*User)
|
||||
if !ok || sshUser == nil {
|
||||
s.logger.Error("Error retrieving credentials from session")
|
||||
s.CloseSession(session)
|
||||
return
|
||||
}
|
||||
@@ -86,17 +98,22 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||
return
|
||||
}
|
||||
|
||||
uidInt, uidErr := stringToUint32(sshUser.Uid)
|
||||
gidInt, gidErr := stringToUint32(sshUser.Gid)
|
||||
if uidErr != nil || gidErr != nil {
|
||||
uidInt, err := stringToUint32(sshUser.Uid)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Invalid user")
|
||||
s.CloseSession(session)
|
||||
return
|
||||
}
|
||||
gidInt, err := stringToUint32(sshUser.Gid)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Invalid user group")
|
||||
s.CloseSession(session)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}}
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Name))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir))
|
||||
cmd.Dir = sshUser.HomeDir
|
||||
psuedoTTY, err := pty.Start(cmd)
|
||||
|
Reference in New Issue
Block a user