mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 17:59:58 +00:00
AUTH-2014: Checks users login shell
This commit is contained in:
@@ -64,20 +64,17 @@ func (s *SSHServer) Start() error {
|
||||
func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||
|
||||
// Get uid and gid of user attempting to login
|
||||
uid, gid, err := getUser(session.User())
|
||||
sshUser, err := lookupUser(session.User())
|
||||
if err != nil {
|
||||
if _, err := io.WriteString(session, "Invalid credentials\n"); err != nil {
|
||||
s.logger.WithError(err).Error("Invalid credentials: Failed to write to SSH session")
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to close SSH session")
|
||||
}
|
||||
s.CloseSession(session)
|
||||
return
|
||||
}
|
||||
|
||||
// Spawn shell under user
|
||||
cmd := exec.Command("/bin/bash")
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uid, Gid: gid}}
|
||||
cmd := exec.Command(sshUser.Shell)
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
if !isPty {
|
||||
@@ -85,20 +82,27 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||
s.logger.WithError(err).Error("No PTY requested: Failed to write to SSH session")
|
||||
}
|
||||
|
||||
if err := session.Exit(1); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to close SSH session")
|
||||
}
|
||||
s.CloseSession(session)
|
||||
return
|
||||
}
|
||||
|
||||
uidInt, uidErr := stringToUint32(sshUser.Uid)
|
||||
gidInt, gidErr := stringToUint32(sshUser.Gid)
|
||||
if uidErr != nil || gidErr != nil {
|
||||
s.logger.WithError(err).Error("Invalid user")
|
||||
s.CloseSession(session)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}}
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", defaultShellPrompt))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Name))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir))
|
||||
cmd.Dir = sshUser.HomeDir
|
||||
psuedoTTY, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Error("Failed to start pty session")
|
||||
if err := session.Exit(1); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to close SSH session")
|
||||
}
|
||||
s.CloseSession(session)
|
||||
close(s.shutdownC)
|
||||
return
|
||||
}
|
||||
@@ -108,9 +112,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||
for win := range winCh {
|
||||
if errNo := setWinsize(psuedoTTY, win.Width, win.Height); errNo != 0 {
|
||||
s.logger.WithError(err).Error("Failed to set pty window size: ", err.Error())
|
||||
if err := session.Exit(1); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to close SSH session")
|
||||
}
|
||||
s.CloseSession(session)
|
||||
close(s.shutdownC)
|
||||
return
|
||||
}
|
||||
@@ -152,6 +154,12 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHServer) CloseSession(session ssh.Session) {
|
||||
if err := session.Exit(1); err != nil {
|
||||
s.logger.WithError(err).Error("Failed to close SSH session")
|
||||
}
|
||||
}
|
||||
|
||||
// Sets PTY window size for terminal
|
||||
func setWinsize(f *os.File, w, h int) syscall.Errno {
|
||||
_, _, errNo := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
||||
@@ -159,19 +167,8 @@ func setWinsize(f *os.File, w, h int) syscall.Errno {
|
||||
return errNo
|
||||
}
|
||||
|
||||
// Only works on POSIX systems
|
||||
func getUser(username string) (uint32, uint32, error) {
|
||||
sshUser, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
uid, err := strconv.ParseUint(sshUser.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
gid, err := strconv.ParseUint(sshUser.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return uint32(uid), uint32(gid), nil
|
||||
func stringToUint32(str string) (uint32, error) {
|
||||
uid, err := strconv.ParseUint(str, 10, 32)
|
||||
return uint32(uid), err
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user