mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 01:09:57 +00:00
AUTH-2394 added socks5 proxy
This commit is contained in:
77
socks/auth_handler.go
Normal file
77
socks/auth_handler.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
// NoAuth means no authentication is used when connecting
|
||||
NoAuth = uint8(0)
|
||||
|
||||
// UserPassAuth means a user/password is used when connecting
|
||||
UserPassAuth = uint8(2)
|
||||
|
||||
noAcceptable = uint8(255)
|
||||
userAuthVersion = uint8(1)
|
||||
authSuccess = uint8(0)
|
||||
authFailure = uint8(1)
|
||||
)
|
||||
|
||||
// AuthHandler handles socks authenication requests
|
||||
type AuthHandler interface {
|
||||
Handle(io.Reader, io.Writer) error
|
||||
Register(uint8, Authenticator)
|
||||
}
|
||||
|
||||
// StandardAuthHandler loads the default authenticators
|
||||
type StandardAuthHandler struct {
|
||||
authenticators map[uint8]Authenticator
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a default auth handler
|
||||
func NewAuthHandler() AuthHandler {
|
||||
defaults := make(map[uint8]Authenticator)
|
||||
defaults[NoAuth] = NewNoAuthAuthenticator()
|
||||
return &StandardAuthHandler{
|
||||
authenticators: defaults,
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds/replaces an Authenticator to use when handling Authentication requests
|
||||
func (h *StandardAuthHandler) Register(method uint8, a Authenticator) {
|
||||
h.authenticators[method] = a
|
||||
}
|
||||
|
||||
// Handle gets the methods from the SOCKS5 client and authenicates with the first supported method
|
||||
func (h *StandardAuthHandler) Handle(bufConn io.Reader, conn io.Writer) error {
|
||||
methods, err := readMethods(bufConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to read auth methods: %v", err)
|
||||
}
|
||||
|
||||
// first supported method is used
|
||||
for _, method := range methods {
|
||||
authenticator := h.authenticators[method]
|
||||
if authenticator != nil {
|
||||
return authenticator.Handle(bufConn, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// failed to authenticate. No supported authentication type found
|
||||
conn.Write([]byte{socks5Version, noAcceptable})
|
||||
return fmt.Errorf("unknown authentication type")
|
||||
}
|
||||
|
||||
// readMethods is used to read the number and type of methods
|
||||
func readMethods(r io.Reader) ([]byte, error) {
|
||||
header := []byte{0}
|
||||
if _, err := r.Read(header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numMethods := int(header[0])
|
||||
methods := make([]byte, numMethods)
|
||||
_, err := io.ReadAtLeast(r, methods, numMethods)
|
||||
return methods, err
|
||||
}
|
87
socks/authenticator.go
Normal file
87
socks/authenticator.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Authenticator is the connection passed in as a reader/writer to support different authentication types
|
||||
type Authenticator interface {
|
||||
Handle(io.Reader, io.Writer) error
|
||||
}
|
||||
|
||||
// NoAuthAuthenticator is used to handle the No Authentication mode
|
||||
type NoAuthAuthenticator struct{}
|
||||
|
||||
// NewNoAuthAuthenticator creates a authless Authenticator
|
||||
func NewNoAuthAuthenticator() Authenticator {
|
||||
return &NoAuthAuthenticator{}
|
||||
}
|
||||
|
||||
// Handle writes back the version and NoAuth
|
||||
func (a *NoAuthAuthenticator) Handle(reader io.Reader, writer io.Writer) error {
|
||||
_, err := writer.Write([]byte{socks5Version, NoAuth})
|
||||
return err
|
||||
}
|
||||
|
||||
// UserPassAuthAuthenticator is used to handle the user/password mode
|
||||
type UserPassAuthAuthenticator struct {
|
||||
IsValid func(string, string) bool
|
||||
}
|
||||
|
||||
// NewUserPassAuthAuthenticator creates a new username/password validator Authenticator
|
||||
func NewUserPassAuthAuthenticator(isValid func(string, string) bool) Authenticator {
|
||||
return &UserPassAuthAuthenticator{
|
||||
IsValid: isValid,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle writes back the version and NoAuth
|
||||
func (a *UserPassAuthAuthenticator) Handle(reader io.Reader, writer io.Writer) error {
|
||||
if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the version and username length
|
||||
header := []byte{0, 0}
|
||||
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure compatibility. Someone call E-harmony
|
||||
if header[0] != userAuthVersion {
|
||||
return fmt.Errorf("Unsupported auth version: %v", header[0])
|
||||
}
|
||||
|
||||
// Get the user name
|
||||
userLen := int(header[1])
|
||||
user := make([]byte, userLen)
|
||||
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the password length
|
||||
if _, err := reader.Read(header[:1]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the password
|
||||
passLen := int(header[0])
|
||||
pass := make([]byte, passLen)
|
||||
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify the password
|
||||
if a.IsValid(string(user), string(pass)) {
|
||||
_, err := writer.Write([]byte{userAuthVersion, authSuccess})
|
||||
return err
|
||||
}
|
||||
|
||||
// password failed. Write back failure
|
||||
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("User authentication failed")
|
||||
}
|
57
socks/connection_handler.go
Normal file
57
socks/connection_handler.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ConnectionHandler is the Serve method to handle connections
|
||||
// from a local TCP listener of the standard library (net.Listener)
|
||||
type ConnectionHandler interface {
|
||||
Serve(io.ReadWriter) error
|
||||
}
|
||||
|
||||
// StandardConnectionHandler is the base implementation of handling SOCKS5 requests
|
||||
type StandardConnectionHandler struct {
|
||||
requestHandler RequestHandler
|
||||
authHandler AuthHandler
|
||||
}
|
||||
|
||||
// NewConnectionHandler creates a standard SOCKS5 connection handler
|
||||
// This process connections from a generic TCP listener from the standard library
|
||||
func NewConnectionHandler(requestHandler RequestHandler) ConnectionHandler {
|
||||
return &StandardConnectionHandler{
|
||||
requestHandler: requestHandler,
|
||||
authHandler: NewAuthHandler(),
|
||||
}
|
||||
}
|
||||
|
||||
// Serve process new connection created after calling `Accept()` in the standard library
|
||||
func (h *StandardConnectionHandler) Serve(c io.ReadWriter) error {
|
||||
bufConn := bufio.NewReader(c)
|
||||
|
||||
// read the version byte
|
||||
version := []byte{0}
|
||||
if _, err := bufConn.Read(version); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure compatibility
|
||||
if version[0] != socks5Version {
|
||||
return fmt.Errorf("Unsupported SOCKS version: %v", version)
|
||||
}
|
||||
|
||||
// handle auth
|
||||
if err := h.authHandler.Handle(bufConn, c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// process command/request
|
||||
req, err := NewRequest(bufConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.requestHandler.Handle(req, c)
|
||||
}
|
88
socks/connection_handler_test.go
Normal file
88
socks/connection_handler_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type successResponse struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func sendSocksRequest(t *testing.T) []byte {
|
||||
dialer, err := proxy.SOCKS5("tcp", "127.0.0.1:8086", nil, proxy.Direct)
|
||||
assert.NoError(t, err)
|
||||
|
||||
httpTransport := &http.Transport{}
|
||||
httpClient := &http.Client{Transport: httpTransport}
|
||||
// set our socks5 as the dialer
|
||||
httpTransport.Dial = dialer.Dial
|
||||
|
||||
req, err := http.NewRequest("GET", "http://127.0.0.1:8085", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func startTestServer(t *testing.T, httpHandler func(w http.ResponseWriter, r *http.Request)) {
|
||||
// create a socks server
|
||||
requestHandler := NewRequestHandler(NewNetDialer())
|
||||
socksServer := NewConnectionHandler(requestHandler)
|
||||
listener, err := net.Listen("tcp", ":8086")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
for {
|
||||
conn, _ := listener.Accept()
|
||||
go socksServer.Serve(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
// create an http server
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", httpHandler)
|
||||
|
||||
// start the servers
|
||||
go http.ListenAndServe(":8085", mux)
|
||||
|
||||
}
|
||||
|
||||
func respondWithJSON(w http.ResponseWriter, v interface{}, status int) {
|
||||
data, _ := json.Marshal(v)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func OkJSONResponseHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := successResponse{
|
||||
Status: "ok",
|
||||
}
|
||||
respondWithJSON(w, resp, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestSocksConnection(t *testing.T) {
|
||||
startTestServer(t, OkJSONResponseHandler)
|
||||
b := sendSocksRequest(t)
|
||||
assert.True(t, len(b) > 0, "no data returned!")
|
||||
|
||||
var resp successResponse
|
||||
json.Unmarshal(b, &resp)
|
||||
|
||||
assert.True(t, resp.Status == "ok", "response didn't return ok")
|
||||
}
|
57
socks/dialer.go
Normal file
57
socks/dialer.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Dialer is used to provided the transport of the proxy
|
||||
type Dialer interface {
|
||||
Dial(string) (io.ReadWriteCloser, *AddrSpec, error)
|
||||
}
|
||||
|
||||
// NetDialer is a standard TCP dialer
|
||||
type NetDialer struct {
|
||||
}
|
||||
|
||||
// NewNetDialer creates a new dialer
|
||||
func NewNetDialer() Dialer {
|
||||
return &NetDialer{}
|
||||
}
|
||||
|
||||
// Dial is a base TCP dialer
|
||||
func (d *NetDialer) Dial(address string) (io.ReadWriteCloser, *AddrSpec, error) {
|
||||
c, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
local := c.LocalAddr().(*net.TCPAddr)
|
||||
addr := AddrSpec{IP: local.IP, Port: local.Port}
|
||||
|
||||
return c, &addr, nil
|
||||
}
|
||||
|
||||
// ConnDialer is like NetDialer but with an existing TCP dialer already created
|
||||
type ConnDialer struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// NewConnDialer creates a new dialer with a already created net.conn (TCP expected)
|
||||
func NewConnDialer(conn net.Conn) Dialer {
|
||||
return &ConnDialer{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial is a TCP dialer but already created
|
||||
func (d *ConnDialer) Dial(address string) (io.ReadWriteCloser, *AddrSpec, error) {
|
||||
local, ok := d.conn.LocalAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("not a tcp connection")
|
||||
}
|
||||
|
||||
addr := AddrSpec{IP: local.IP, Port: local.Port}
|
||||
return d.conn, &addr, nil
|
||||
}
|
195
socks/request.go
Normal file
195
socks/request.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
// version
|
||||
socks5Version = uint8(5)
|
||||
|
||||
// commands https://tools.ietf.org/html/rfc1928#section-4
|
||||
connectCommand = uint8(1)
|
||||
bindCommand = uint8(2)
|
||||
associateCommand = uint8(3)
|
||||
|
||||
// address types
|
||||
ipv4Address = uint8(1)
|
||||
fqdnAddress = uint8(3)
|
||||
ipv6Address = uint8(4)
|
||||
)
|
||||
|
||||
// https://tools.ietf.org/html/rfc1928#section-6
|
||||
const (
|
||||
successReply uint8 = iota
|
||||
serverFailure
|
||||
ruleFailure
|
||||
networkUnreachable
|
||||
hostUnreachable
|
||||
connectionRefused
|
||||
ttlExpired
|
||||
commandNotSupported
|
||||
addrTypeNotSupported
|
||||
)
|
||||
|
||||
// AddrSpec is used to return the target IPv4, IPv6, or a FQDN
|
||||
type AddrSpec struct {
|
||||
FQDN string
|
||||
IP net.IP
|
||||
Port int
|
||||
}
|
||||
|
||||
// String gives a host version of the Address
|
||||
func (a *AddrSpec) String() string {
|
||||
if a.FQDN != "" {
|
||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
||||
}
|
||||
|
||||
// Address returns a string suitable to dial; prefer returning IP-based
|
||||
// address, fallback to FQDN
|
||||
func (a AddrSpec) Address() string {
|
||||
if len(a.IP) != 0 {
|
||||
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
|
||||
}
|
||||
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
|
||||
}
|
||||
|
||||
// Request is a SOCKS5 command with supporting field of the connection
|
||||
type Request struct {
|
||||
// Protocol version
|
||||
Version uint8
|
||||
// Requested command
|
||||
Command uint8
|
||||
// AddrSpec of the destination
|
||||
DestAddr *AddrSpec
|
||||
// reading from the connection
|
||||
bufConn io.Reader
|
||||
}
|
||||
|
||||
// NewRequest creates a new request from the connection data stream
|
||||
func NewRequest(bufConn io.Reader) (*Request, error) {
|
||||
// Read the version byte
|
||||
header := []byte{0, 0, 0}
|
||||
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
|
||||
return nil, fmt.Errorf("Failed to get command version: %v", err)
|
||||
}
|
||||
|
||||
// ensure compatibility
|
||||
if header[0] != socks5Version {
|
||||
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
|
||||
}
|
||||
|
||||
// Read in the destination address
|
||||
dest, err := readAddrSpec(bufConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Request{
|
||||
Version: socks5Version,
|
||||
Command: header[1],
|
||||
DestAddr: dest,
|
||||
bufConn: bufConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
|
||||
var addrType uint8
|
||||
var addrBody []byte
|
||||
var addrPort uint16
|
||||
switch {
|
||||
case addr == nil:
|
||||
addrType = ipv4Address
|
||||
addrBody = []byte{0, 0, 0, 0}
|
||||
addrPort = 0
|
||||
|
||||
case addr.FQDN != "":
|
||||
addrType = fqdnAddress
|
||||
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
|
||||
addrPort = uint16(addr.Port)
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
addrType = ipv4Address
|
||||
addrBody = []byte(addr.IP.To4())
|
||||
addrPort = uint16(addr.Port)
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
addrType = ipv6Address
|
||||
addrBody = []byte(addr.IP.To16())
|
||||
addrPort = uint16(addr.Port)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("Failed to format address: %v", addr)
|
||||
}
|
||||
|
||||
// Format the message
|
||||
msg := make([]byte, 6+len(addrBody))
|
||||
msg[0] = socks5Version
|
||||
msg[1] = resp
|
||||
msg[2] = 0 // Reserved
|
||||
msg[3] = addrType
|
||||
copy(msg[4:], addrBody)
|
||||
msg[4+len(addrBody)] = byte(addrPort >> 8)
|
||||
msg[4+len(addrBody)+1] = byte(addrPort & 0xff)
|
||||
|
||||
// Send the message
|
||||
_, err := w.Write(msg)
|
||||
return err
|
||||
}
|
||||
|
||||
// readAddrSpec is used to read AddrSpec.
|
||||
// Expects an address type byte, followed by the address and port
|
||||
func readAddrSpec(r io.Reader) (*AddrSpec, error) {
|
||||
d := &AddrSpec{}
|
||||
|
||||
// Get the address type
|
||||
addrType := []byte{0}
|
||||
if _, err := r.Read(addrType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle on a per type basis
|
||||
switch addrType[0] {
|
||||
case ipv4Address:
|
||||
addr := make([]byte, 4)
|
||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.IP = net.IP(addr)
|
||||
|
||||
case ipv6Address:
|
||||
addr := make([]byte, 16)
|
||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.IP = net.IP(addr)
|
||||
|
||||
case fqdnAddress:
|
||||
if _, err := r.Read(addrType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addrLen := int(addrType[0])
|
||||
fqdn := make([]byte, addrLen)
|
||||
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.FQDN = string(fqdn)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("Unrecognized address type")
|
||||
}
|
||||
|
||||
// Read the port
|
||||
port := []byte{0, 0}
|
||||
if _, err := io.ReadAtLeast(r, port, 2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.Port = (int(port[0]) << 8) | int(port[1])
|
||||
|
||||
return d, nil
|
||||
}
|
106
socks/request_handler.go
Normal file
106
socks/request_handler.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RequestHandler is the functions needed to handle a SOCKS5 command
|
||||
type RequestHandler interface {
|
||||
Handle(*Request, io.ReadWriter) error
|
||||
}
|
||||
|
||||
// StandardRequestHandler implements the base socks5 command processing
|
||||
type StandardRequestHandler struct {
|
||||
dialer Dialer
|
||||
}
|
||||
|
||||
// NewRequestHandler creates a standard SOCKS5 request handler
|
||||
// This handles the SOCKS5 commands and proxies them to their destination
|
||||
func NewRequestHandler(dialer Dialer) RequestHandler {
|
||||
return &StandardRequestHandler{
|
||||
dialer: dialer,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle processes and responds to socks5 commands
|
||||
func (h *StandardRequestHandler) Handle(req *Request, conn io.ReadWriter) error {
|
||||
switch req.Command {
|
||||
case connectCommand:
|
||||
return h.handleConnect(conn, req)
|
||||
case bindCommand:
|
||||
return h.handleBind(conn, req)
|
||||
case associateCommand:
|
||||
return h.handleAssociate(conn, req)
|
||||
default:
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Unsupported command: %v", req.Command)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnect is used to handle a connect command
|
||||
func (h *StandardRequestHandler) handleConnect(conn io.ReadWriter, req *Request) error {
|
||||
target, localAddr, err := h.dialer.Dial(req.DestAddr.Address())
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
resp := hostUnreachable
|
||||
if strings.Contains(msg, "refused") {
|
||||
resp = connectionRefused
|
||||
} else if strings.Contains(msg, "network is unreachable") {
|
||||
resp = networkUnreachable
|
||||
}
|
||||
if err := sendReply(conn, resp, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
|
||||
}
|
||||
defer target.Close()
|
||||
|
||||
// Send success
|
||||
if err := sendReply(conn, successReply, localAddr); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
|
||||
// Start proxying
|
||||
proxyDone := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, e := io.Copy(target, req.bufConn)
|
||||
proxyDone <- e
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, e := io.Copy(conn, target)
|
||||
proxyDone <- e
|
||||
}()
|
||||
|
||||
// Wait for both
|
||||
for i := 0; i < 2; i++ {
|
||||
e := <-proxyDone
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleBind is used to handle a bind command
|
||||
// TODO: Support bind command
|
||||
func (h *StandardRequestHandler) handleBind(conn io.ReadWriter, req *Request) error {
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleAssociate is used to handle a connect command
|
||||
// TODO: Support associate command
|
||||
func (h *StandardRequestHandler) handleAssociate(conn io.ReadWriter, req *Request) error {
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
28
socks/request_handler_test.go
Normal file
28
socks/request_handler_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUnsupportedBind(t *testing.T) {
|
||||
req := createRequest(t, socks5Version, bindCommand, "2001:db8::68", 1337, false)
|
||||
var b bytes.Buffer
|
||||
|
||||
requestHandler := NewRequestHandler(NewNetDialer())
|
||||
err := requestHandler.Handle(req, &b)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, b.Bytes()[1] == commandNotSupported, "expected a response")
|
||||
}
|
||||
|
||||
func TestUnsupportedAssociate(t *testing.T) {
|
||||
req := createRequest(t, socks5Version, associateCommand, "127.0.0.1", 1337, false)
|
||||
var b bytes.Buffer
|
||||
|
||||
requestHandler := NewRequestHandler(NewNetDialer())
|
||||
err := requestHandler.Handle(req, &b)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, b.Bytes()[1] == commandNotSupported, "expected a response")
|
||||
}
|
69
socks/request_test.go
Normal file
69
socks/request_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createRequestData(version, command uint8, ip net.IP, port uint16) []byte {
|
||||
// set the command
|
||||
b := []byte{version, command, 0}
|
||||
|
||||
// append the ip
|
||||
if len(ip) == net.IPv4len {
|
||||
b = append(b, 1)
|
||||
b = append(b, ip.To4()...)
|
||||
} else {
|
||||
b = append(b, 4)
|
||||
b = append(b, ip.To16()...)
|
||||
}
|
||||
|
||||
// append the port
|
||||
p := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(p, port)
|
||||
b = append(b, p...)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func createRequest(t *testing.T, version, command uint8, ipStr string, port uint16, shouldFail bool) *Request {
|
||||
ip := net.ParseIP(ipStr)
|
||||
data := createRequestData(version, command, ip, port)
|
||||
reader := bytes.NewReader(data)
|
||||
req, err := NewRequest(reader)
|
||||
if shouldFail {
|
||||
assert.Error(t, err)
|
||||
return nil
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, req.Version == socks5Version, "version doesn't match expectation: %v", req.Version)
|
||||
assert.True(t, req.Command == command, "command doesn't match expectation: %v", req.Command)
|
||||
assert.True(t, req.DestAddr.Port == int(port), "port doesn't match expectation: %v", req.DestAddr.Port)
|
||||
assert.True(t, req.DestAddr.IP.String() == ipStr, "ip doesn't match expectation: %v", req.DestAddr.IP.String())
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func TestValidConnectRequest(t *testing.T) {
|
||||
createRequest(t, socks5Version, connectCommand, "127.0.0.1", 1337, false)
|
||||
}
|
||||
|
||||
func TestValidBindRequest(t *testing.T) {
|
||||
createRequest(t, socks5Version, bindCommand, "2001:db8::68", 1337, false)
|
||||
}
|
||||
|
||||
func TestValidAssociateRequest(t *testing.T) {
|
||||
createRequest(t, socks5Version, associateCommand, "127.0.0.1", 1234, false)
|
||||
}
|
||||
|
||||
func TestInValidVersionRequest(t *testing.T) {
|
||||
createRequest(t, 4, connectCommand, "127.0.0.1", 1337, true)
|
||||
}
|
||||
|
||||
func TestInValidIPRequest(t *testing.T) {
|
||||
createRequest(t, 4, connectCommand, "127.0.01", 1337, true)
|
||||
}
|
Reference in New Issue
Block a user