diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 38da02d5..adf08b53 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "os" "path/filepath" + "regexp" "sort" "strings" "text/tabwriter" @@ -355,50 +356,65 @@ func buildRouteCommand() *cli.Command { Name: "route", Action: cliutil.ErrorHandler(routeCommand), Usage: "Define what hostname or load balancer can route to this tunnel", - Description: `The route defines what hostname or load balancer can route to this tunnel. + Description: `The route defines what hostname or load balancer will proxy requests to this tunnel. - To route a hostname: cloudflared tunnel route dns - To use this tunnel as a load balancer origin: cloudflared tunnel route lb `, + To route a hostname by creating a CNAME to tunnel's address: + cloudflared tunnel route dns + To use this tunnel as a load balancer origin, creating pool and load balancer if necessary: + cloudflared tunnel route lb `, ArgsUsage: "dns|lb TUNNEL HOSTNAME [LB-POOL]", } } -func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) { +func dnsRouteFromArg(c *cli.Context) (tunnelstore.Route, error) { const ( userHostnameIndex = 2 - expectArgs = 3 + expectedNArgs = 3 ) - if c.NArg() != expectArgs { - return nil, cliutil.UsageError("Expect %d arguments, got %d", expectArgs, c.NArg()) + if c.NArg() != expectedNArgs { + return nil, cliutil.UsageError("Expected %d arguments, got %d", expectedNArgs, c.NArg()) } userHostname := c.Args().Get(userHostnameIndex) if userHostname == "" { return nil, cliutil.UsageError("The third argument should be the hostname") + } else if !validateName(userHostname) { + return nil, errors.Errorf("%s is not a valid hostname", userHostname) } return tunnelstore.NewDNSRoute(userHostname), nil } -func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) { +func lbRouteFromArg(c *cli.Context) (tunnelstore.Route, error) { const ( lbNameIndex = 2 lbPoolIndex = 3 - expectMinArgs = 3 + expectedNArgs = 4 ) - if c.NArg() < expectMinArgs { - return nil, cliutil.UsageError("Expect at least %d arguments, got %d", expectMinArgs, c.NArg()) + if c.NArg() != expectedNArgs { + return nil, cliutil.UsageError("Expected %d arguments, got %d", expectedNArgs, c.NArg()) } lbName := c.Args().Get(lbNameIndex) if lbName == "" { return nil, cliutil.UsageError("The third argument should be the load balancer name") + } else if !validateName(lbName) { + return nil, errors.Errorf("%s is not a valid load balancer name", lbName) } + lbPool := c.Args().Get(lbPoolIndex) if lbPool == "" { - lbPool = defaultPoolName(tunnelID) + return nil, cliutil.UsageError("The fourth argument should be the pool name") + } else if !validateName(lbPool) { + return nil, errors.Errorf("%s is not a valid pool name", lbPool) } return tunnelstore.NewLBRoute(lbName, lbPool), nil } +var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$") + +func validateName(s string) bool { + return nameRegex.MatchString(s) +} + func routeCommand(c *cli.Context) error { if c.NArg() < 2 { return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`) @@ -419,7 +435,7 @@ func routeCommand(c *cli.Context) error { if err != nil { return err } - r, err = dnsRouteFromArg(c, tunnelID) + r, err = dnsRouteFromArg(c) if err != nil { return err } @@ -428,7 +444,7 @@ func routeCommand(c *cli.Context) error { if err != nil { return err } - r, err = lbRouteFromArg(c, tunnelID) + r, err = lbRouteFromArg(c) if err != nil { return err } @@ -443,6 +459,3 @@ func routeCommand(c *cli.Context) error { return nil } -func defaultPoolName(tunnelID uuid.UUID) string { - return fmt.Sprintf("tunnel:%v", tunnelID) -} diff --git a/cmd/cloudflared/tunnel/subcommands_test.go b/cmd/cloudflared/tunnel/subcommands_test.go index 343419ee..fb1bbbdf 100644 --- a/cmd/cloudflared/tunnel/subcommands_test.go +++ b/cmd/cloudflared/tunnel/subcommands_test.go @@ -98,3 +98,27 @@ func TestTunnelfilePath(t *testing.T) { expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID) assert.Equal(t, expected, actual) } + +func TestValidateName(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {name: "", want: false}, + {name: "-", want: false}, + {name: ".", want: false}, + {name: "a b", want: false}, + {name: "a+b", want: false}, + {name: "-ab", want: false}, + + {name: "ab", want: true}, + {name: "ab-c", want: true}, + {name: "abc.def", want: true}, + {name: "_ab_c.-d-ef", want: true}, + } + for _, tt := range tests { + if got := validateName(tt.name); got != tt.want { + t.Errorf("validateName() = %v, want %v", got, tt.want) + } + } +} \ No newline at end of file