TUN-528: Move cloudflared into a separate repo

This commit is contained in:
Areg Harutyunyan
2018-05-01 18:45:06 -05:00
parent e8c621a648
commit d06fc520c7
4726 changed files with 1763680 additions and 0 deletions

View File

@@ -0,0 +1,202 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"fmt"
"net/http"
"regexp"
"strings"
"github.com/mholt/caddy"
)
// SetupIfMatcher parses `if` or `if_op` in the current dispenser block.
// It returns a RequestMatcher and an error if any.
func SetupIfMatcher(controller *caddy.Controller) (RequestMatcher, error) {
var c = controller.Dispenser // copy the dispenser
var matcher IfMatcher
for c.NextBlock() {
switch c.Val() {
case "if":
args1 := c.RemainingArgs()
if len(args1) != 3 {
return matcher, c.ArgErr()
}
ifc, err := newIfCond(args1[0], args1[1], args1[2])
if err != nil {
return matcher, err
}
matcher.ifs = append(matcher.ifs, ifc)
matcher.Enabled = true
case "if_op":
if !c.NextArg() {
return matcher, c.ArgErr()
}
switch c.Val() {
case "and":
matcher.isOr = false
case "or":
matcher.isOr = true
default:
return matcher, c.ArgErr()
}
}
}
return matcher, nil
}
// operators
const (
isOp = "is"
notOp = "not"
hasOp = "has"
startsWithOp = "starts_with"
endsWithOp = "ends_with"
matchOp = "match"
)
// ifCondition is a 'if' condition.
type ifFunc func(a, b string) bool
// ifCond is statement for a IfMatcher condition.
type ifCond struct {
a string
op string
b string
neg bool
rex *regexp.Regexp
f ifFunc
}
// newIfCond creates a new If condition.
func newIfCond(a, op, b string) (ifCond, error) {
i := ifCond{a: a, op: op, b: b}
if strings.HasPrefix(op, "not_") {
i.neg = true
i.op = op[4:]
}
switch i.op {
case isOp:
// It checks for equality.
i.f = i.isFunc
case notOp:
// It checks for inequality.
i.f = i.notFunc
case hasOp:
// It checks if b is a substring of a.
i.f = strings.Contains
case startsWithOp:
// It checks if b is a prefix of a.
i.f = strings.HasPrefix
case endsWithOp:
// It checks if b is a suffix of a.
i.f = strings.HasSuffix
case matchOp:
// It does regexp matching of a against pattern in b and returns if they match.
var err error
if i.rex, err = regexp.Compile(i.b); err != nil {
return ifCond{}, fmt.Errorf("Invalid regular expression: '%s', %v", i.b, err)
}
i.f = i.matchFunc
default:
return ifCond{}, fmt.Errorf("Invalid operator %v", i.op)
}
return i, nil
}
// isFunc is condition for Is operator.
func (i ifCond) isFunc(a, b string) bool {
return a == b
}
// notFunc is condition for Not operator.
func (i ifCond) notFunc(a, b string) bool {
return a != b
}
// matchFunc is condition for Match operator.
func (i ifCond) matchFunc(a, b string) bool {
return i.rex.MatchString(a)
}
// True returns true if the condition is true and false otherwise.
// If r is not nil, it replaces placeholders before comparison.
func (i ifCond) True(r *http.Request) bool {
if i.f != nil {
a, b := i.a, i.b
if r != nil {
replacer := NewReplacer(r, nil, "")
a = replacer.Replace(i.a)
if i.op != matchOp {
b = replacer.Replace(i.b)
}
}
if i.neg {
return !i.f(a, b)
}
return i.f(a, b)
}
return i.neg // false if not negated, true otherwise
}
// IfMatcher is a RequestMatcher for 'if' conditions.
type IfMatcher struct {
Enabled bool // if true, matcher has been configured; otherwise it's no-op
ifs []ifCond // list of If
isOr bool // if true, conditions are 'or' instead of 'and'
}
// Match satisfies RequestMatcher interface.
// It returns true if the conditions in m are true.
func (m IfMatcher) Match(r *http.Request) bool {
if m.isOr {
return m.Or(r)
}
return m.And(r)
}
// And returns true if all conditions in m are true.
func (m IfMatcher) And(r *http.Request) bool {
for _, i := range m.ifs {
if !i.True(r) {
return false
}
}
return true
}
// Or returns true if any of the conditions in m is true.
func (m IfMatcher) Or(r *http.Request) bool {
for _, i := range m.ifs {
if i.True(r) {
return true
}
}
return false
}
// IfMatcherKeyword checks if the next value in the dispenser is a keyword for 'if' config block.
// If true, remaining arguments in the dispinser are cleard to keep the dispenser valid for use.
func IfMatcherKeyword(c *caddy.Controller) bool {
if c.Val() == "if" || c.Val() == "if_op" {
// clear remaining args
c.RemainingArgs()
return true
}
return false
}

View File

@@ -0,0 +1,368 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"context"
"net/http"
"regexp"
"strings"
"testing"
"github.com/mholt/caddy"
)
func TestConditions(t *testing.T) {
tests := []struct {
condition string
isTrue bool
shouldErr bool
}{
{"a is b", false, false},
{"a is a", true, false},
{"a not b", true, false},
{"a not a", false, false},
{"a has a", true, false},
{"a has b", false, false},
{"ba has b", true, false},
{"bab has b", true, false},
{"bab has bb", false, false},
{"a not_has a", false, false},
{"a not_has b", true, false},
{"ba not_has b", false, false},
{"bab not_has b", false, false},
{"bab not_has bb", true, false},
{"bab starts_with bb", false, false},
{"bab starts_with ba", true, false},
{"bab starts_with bab", true, false},
{"bab not_starts_with bb", true, false},
{"bab not_starts_with ba", false, false},
{"bab not_starts_with bab", false, false},
{"bab ends_with bb", false, false},
{"bab ends_with bab", true, false},
{"bab ends_with ab", true, false},
{"bab not_ends_with bb", true, false},
{"bab not_ends_with ab", false, false},
{"bab not_ends_with bab", false, false},
{"a match *", false, true},
{"a match a", true, false},
{"a match .*", true, false},
{"a match a.*", true, false},
{"a match b.*", false, false},
{"ba match b.*", true, false},
{"ba match b[a-z]", true, false},
{"b0 match b[a-z]", false, false},
{"b0a match b[a-z]", false, false},
{"b0a match b[a-z]+", false, false},
{"b0a match b[a-z0-9]+", true, false},
{"bac match b[a-z]{2}", true, false},
{"a not_match *", false, true},
{"a not_match a", false, false},
{"a not_match .*", false, false},
{"a not_match a.*", false, false},
{"a not_match b.*", true, false},
{"ba not_match b.*", false, false},
{"ba not_match b[a-z]", false, false},
{"b0 not_match b[a-z]", true, false},
{"b0a not_match b[a-z]", true, false},
{"b0a not_match b[a-z]+", true, false},
{"b0a not_match b[a-z0-9]+", false, false},
{"bac not_match b[a-z]{2}", false, false},
}
for i, test := range tests {
str := strings.Fields(test.condition)
ifCond, err := newIfCond(str[0], str[1], str[2])
if err != nil {
if !test.shouldErr {
t.Error(err)
}
continue
}
isTrue := ifCond.True(nil)
if isTrue != test.isTrue {
t.Errorf("Test %d: '%s' expected %v found %v", i, test.condition, test.isTrue, isTrue)
}
}
invalidOperators := []string{"ss", "and", "if"}
for _, op := range invalidOperators {
_, err := newIfCond("a", op, "b")
if err == nil {
t.Errorf("Invalid operator %v used, expected error.", op)
}
}
replaceTests := []struct {
url string
condition string
isTrue bool
}{
{"/home", "{uri} match /home", true},
{"/hom", "{uri} match /home", false},
{"/hom", "{uri} starts_with /home", false},
{"/hom", "{uri} starts_with /h", true},
{"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
{"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
}
for i, test := range replaceTests {
r, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Errorf("Test %d: failed to create request: %v", i, err)
continue
}
ctx := context.WithValue(r.Context(), OriginalURLCtxKey, *r.URL)
r = r.WithContext(ctx)
str := strings.Fields(test.condition)
ifCond, err := newIfCond(str[0], str[1], str[2])
if err != nil {
t.Errorf("Test %d: failed to create 'if' condition %v", i, err)
continue
}
isTrue := ifCond.True(r)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
continue
}
}
}
func TestIfMatcher(t *testing.T) {
tests := []struct {
conditions []string
isOr bool
isTrue bool
}{
{
[]string{
"a is a",
"b is b",
"c is c",
},
false,
true,
},
{
[]string{
"a is b",
"b is c",
"c is c",
},
true,
true,
},
{
[]string{
"a is a",
"b is a",
"c is c",
},
false,
false,
},
{
[]string{
"a is b",
"b is c",
"c is a",
},
true,
false,
},
{
[]string{},
false,
true,
},
{
[]string{},
true,
false,
},
}
for i, test := range tests {
matcher := IfMatcher{isOr: test.isOr}
for _, condition := range test.conditions {
str := strings.Fields(condition)
ifCond, err := newIfCond(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
matcher.ifs = append(matcher.ifs, ifCond)
}
isTrue := matcher.Match(nil)
if isTrue != test.isTrue {
t.Errorf("Test %d: expected %v found %v", i, test.isTrue, isTrue)
}
}
}
func TestSetupIfMatcher(t *testing.T) {
rex_b, _ := regexp.Compile("b")
tests := []struct {
input string
shouldErr bool
expected IfMatcher
}{
{`test {
if a match b
}`, false, IfMatcher{
ifs: []ifCond{
{a: "a", op: "match", b: "b", neg: false, rex: rex_b},
},
}},
{`test {
if a match b
if_op or
}`, false, IfMatcher{
ifs: []ifCond{
{a: "a", op: "match", b: "b", neg: false, rex: rex_b},
},
isOr: true,
}},
{`test {
if a match
}`, true, IfMatcher{},
},
{`test {
if a isn't b
}`, true, IfMatcher{},
},
{`test {
if a match b c
}`, true, IfMatcher{},
},
{`test {
if goal has go
if cook not_has go
}`, false, IfMatcher{
ifs: []ifCond{
{a: "goal", op: "has", b: "go", neg: false},
{a: "cook", op: "has", b: "go", neg: true},
},
}},
{`test {
if goal has go
if cook not_has go
if_op and
}`, false, IfMatcher{
ifs: []ifCond{
{a: "goal", op: "has", b: "go", neg: false},
{a: "cook", op: "has", b: "go", neg: true},
},
}},
{`test {
if goal has go
if cook not_has go
if_op not
}`, true, IfMatcher{},
},
}
for i, test := range tests {
c := caddy.NewTestController("http", test.input)
c.Next()
matcher, err := SetupIfMatcher(c)
if err == nil && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if err != nil && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
} else if err != nil && test.shouldErr {
continue
}
test_if, ok := matcher.(IfMatcher)
if !ok {
t.Error("RequestMatcher should be of type IfMatcher")
}
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
if len(test_if.ifs) != len(test.expected.ifs) {
t.Errorf("Test %d: Expected %d ifConditions, found %v", i,
len(test.expected.ifs), len(test_if.ifs))
}
for j, if_c := range test_if.ifs {
expected_c := test.expected.ifs[j]
if if_c.a != expected_c.a {
t.Errorf("Test %d, ifCond %d: Expected A=%s, got %s",
i, j, if_c.a, expected_c.a)
}
if if_c.op != expected_c.op {
t.Errorf("Test %d, ifCond %d: Expected Op=%s, got %s",
i, j, if_c.op, expected_c.op)
}
if if_c.b != expected_c.b {
t.Errorf("Test %d, ifCond %d: Expected B=%s, got %s",
i, j, if_c.b, expected_c.b)
}
if if_c.neg != expected_c.neg {
t.Errorf("Test %d, ifCond %d: Expected Neg=%v, got %v",
i, j, if_c.neg, expected_c.neg)
}
if expected_c.rex != nil && if_c.rex == nil {
t.Errorf("Test %d, ifCond %d: Expected Rex=%v, got <nil>",
i, j, expected_c.rex)
}
if expected_c.rex == nil && if_c.rex != nil {
t.Errorf("Test %d, ifCond %d: Expected Rex=<nil>, got %v",
i, j, if_c.rex)
}
if expected_c.rex != nil && if_c.rex != nil {
if if_c.rex.String() != expected_c.rex.String() {
t.Errorf("Test %d, ifCond %d: Expected Rex=%v, got %v",
i, j, if_c.rex, expected_c.rex)
}
}
}
}
}
func TestIfMatcherKeyword(t *testing.T) {
tests := []struct {
keyword string
expected bool
}{
{"if", true},
{"ifs", false},
{"tls", false},
{"http", false},
{"if_op", true},
{"if_type", false},
{"if_cond", false},
}
for i, test := range tests {
c := caddy.NewTestController("http", test.keyword)
c.Next()
valid := IfMatcherKeyword(c)
if valid != test.expected {
t.Errorf("Test %d: expected %v found %v", i, test.expected, valid)
}
}
}

View File

@@ -0,0 +1,70 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"fmt"
)
var (
_ error = NonHijackerError{}
_ error = NonFlusherError{}
_ error = NonCloseNotifierError{}
_ error = NonPusherError{}
)
// NonHijackerError is more descriptive error caused by a non hijacker
type NonHijackerError struct {
// underlying type which doesn't implement Hijack
Underlying interface{}
}
// Implement Error
func (h NonHijackerError) Error() string {
return fmt.Sprintf("%T is not a hijacker", h.Underlying)
}
// NonFlusherError is more descriptive error caused by a non flusher
type NonFlusherError struct {
// underlying type which doesn't implement Flush
Underlying interface{}
}
// Implement Error
func (f NonFlusherError) Error() string {
return fmt.Sprintf("%T is not a flusher", f.Underlying)
}
// NonCloseNotifierError is more descriptive error caused by a non closeNotifier
type NonCloseNotifierError struct {
// underlying type which doesn't implement CloseNotify
Underlying interface{}
}
// Implement Error
func (c NonCloseNotifierError) Error() string {
return fmt.Sprintf("%T is not a closeNotifier", c.Underlying)
}
// NonPusherError is more descriptive error caused by a non pusher
type NonPusherError struct {
// underlying type which doesn't implement pusher
Underlying interface{}
}
// Implement Error
func (c NonPusherError) Error() string {
return fmt.Sprintf("%T is not a pusher", c.Underlying)
}

View File

@@ -0,0 +1,213 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"fmt"
"net"
"net/http"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddytls"
)
func activateHTTPS(cctx caddy.Context) error {
operatorPresent := !caddy.Started()
if !caddy.Quiet && operatorPresent {
fmt.Print("Activating privacy features... ")
}
ctx := cctx.(*httpContext)
// pre-screen each config and earmark the ones that qualify for managed TLS
markQualifiedForAutoHTTPS(ctx.siteConfigs)
// place certificates and keys on disk
for _, c := range ctx.siteConfigs {
if c.TLS.OnDemand {
continue // obtain these certificates on-demand instead
}
err := c.TLS.ObtainCert(c.TLS.Hostname, operatorPresent)
if err != nil {
return err
}
}
// update TLS configurations
err := enableAutoHTTPS(ctx.siteConfigs, true)
if err != nil {
return err
}
// set up redirects
ctx.siteConfigs = makePlaintextRedirects(ctx.siteConfigs)
// renew all relevant certificates that need renewal. this is important
// to do right away so we guarantee that renewals aren't missed, and
// also the user can respond to any potential errors that occur.
// (skip if upgrading, because the parent process is likely already listening
// on the ports we'd need to do ACME before we finish starting; parent process
// already running renewal ticker, so renewal won't be missed anyway.)
if !caddy.IsUpgrade() {
err = caddytls.RenewManagedCertificates(true)
if err != nil {
return err
}
}
if !caddy.Quiet && operatorPresent {
fmt.Println("done.")
}
return nil
}
// markQualifiedForAutoHTTPS scans each config and, if it
// qualifies for managed TLS, it sets the Managed field of
// the TLS config to true.
func markQualifiedForAutoHTTPS(configs []*SiteConfig) {
for _, cfg := range configs {
if caddytls.QualifiesForManagedTLS(cfg) && cfg.Addr.Scheme != "http" {
cfg.TLS.Managed = true
}
}
}
// enableAutoHTTPS configures each config to use TLS according to default settings.
// It will only change configs that are marked as managed but not on-demand, and
// assumes that certificates and keys are already on disk. If loadCertificates is
// true, the certificates will be loaded from disk into the cache for this process
// to use. If false, TLS will still be enabled and configured with default settings,
// but no certificates will be parsed loaded into the cache, and the returned error
// value will always be nil.
func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
for _, cfg := range configs {
if cfg == nil || cfg.TLS == nil || !cfg.TLS.Managed || cfg.TLS.OnDemand {
continue
}
cfg.TLS.Enabled = true
cfg.Addr.Scheme = "https"
if loadCertificates && caddytls.HostQualifies(cfg.TLS.Hostname) {
_, err := cfg.TLS.CacheManagedCertificate(cfg.TLS.Hostname)
if err != nil {
return err
}
}
// Make sure any config values not explicitly set are set to default
caddytls.SetDefaultTLSParams(cfg.TLS)
// Set default port of 443 if not explicitly set
if cfg.Addr.Port == "" &&
cfg.TLS.Enabled &&
(!cfg.TLS.Manual || cfg.TLS.OnDemand) &&
cfg.Addr.Host != "localhost" {
cfg.Addr.Port = HTTPSPort
}
}
return nil
}
// makePlaintextRedirects sets up redirects from port 80 to the relevant HTTPS
// hosts. You must pass in all configs, not just configs that qualify, since
// we must know whether the same host already exists on port 80, and those would
// not be in a list of configs that qualify for automatic HTTPS. This function will
// only set up redirects for configs that qualify. It returns the updated list of
// all configs.
func makePlaintextRedirects(allConfigs []*SiteConfig) []*SiteConfig {
for i, cfg := range allConfigs {
if cfg.TLS.Managed &&
!hostHasOtherPort(allConfigs, i, HTTPPort) &&
(cfg.Addr.Port == HTTPSPort || !hostHasOtherPort(allConfigs, i, HTTPSPort)) {
allConfigs = append(allConfigs, redirPlaintextHost(cfg))
}
}
return allConfigs
}
// hostHasOtherPort returns true if there is another config in the list with the same
// hostname that has port otherPort, or false otherwise. All the configs are checked
// against the hostname of allConfigs[thisConfigIdx].
func hostHasOtherPort(allConfigs []*SiteConfig, thisConfigIdx int, otherPort string) bool {
for i, otherCfg := range allConfigs {
if i == thisConfigIdx {
continue // has to be a config OTHER than the one we're comparing against
}
if otherCfg.Addr.Host == allConfigs[thisConfigIdx].Addr.Host &&
otherCfg.Addr.Port == otherPort {
return true
}
}
return false
}
// redirPlaintextHost returns a new plaintext HTTP configuration for
// a virtualHost that simply redirects to cfg, which is assumed to
// be the HTTPS configuration. The returned configuration is set
// to listen on HTTPPort. The TLS field of cfg must not be nil.
func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
redirPort := cfg.Addr.Port
if redirPort == HTTPSPort {
// By default, HTTPSPort should be DefaultHTTPSPort,
// which of course doesn't need to be explicitly stated
// in the Location header. Even if HTTPSPort is changed
// so that it is no longer DefaultHTTPSPort, we shouldn't
// append it to the URL in the Location because changing
// the HTTPS port is assumed to be an internal-only change
// (in other words, we assume port forwarding is going on);
// but redirects go back to a presumably-external client.
// (If redirect clients are also internal, that is more
// advanced, and the user should configure HTTP->HTTPS
// redirects themselves.)
redirPort = ""
}
redirMiddleware := func(next Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
// Construct the URL to which to redirect. Note that the Host in a
// request might contain a port, but we just need the hostname from
// it; and we'll set the port if needed.
toURL := "https://"
requestHost, _, err := net.SplitHostPort(r.Host)
if err != nil {
requestHost = r.Host // Host did not contain a port, so use the whole value
}
if redirPort == "" {
toURL += requestHost
} else {
toURL += net.JoinHostPort(requestHost, redirPort)
}
toURL += r.URL.RequestURI()
w.Header().Set("Connection", "close")
http.Redirect(w, r, toURL, http.StatusMovedPermanently)
return 0, nil
})
}
host := cfg.Addr.Host
port := HTTPPort
addr := net.JoinHostPort(host, port)
return &SiteConfig{
Addr: Address{Original: addr, Host: host, Port: port},
ListenHost: cfg.ListenHost,
middleware: []Middleware{redirMiddleware},
TLS: &caddytls.Config{AltHTTPPort: cfg.TLS.AltHTTPPort, AltTLSSNIPort: cfg.TLS.AltTLSSNIPort},
Timeouts: cfg.Timeouts,
}
}

View File

@@ -0,0 +1,226 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/mholt/caddy/caddytls"
)
func TestRedirPlaintextHost(t *testing.T) {
for i, testcase := range []struct {
Host string // used for the site config
Port string
ListenHost string
RequestHost string // if different from Host
}{
{
Host: "foohost",
},
{
Host: "foohost",
Port: "80",
},
{
Host: "foohost",
Port: "1234",
},
{
Host: "foohost",
ListenHost: "93.184.216.34",
},
{
Host: "foohost",
Port: "1234",
ListenHost: "93.184.216.34",
},
{
Host: "foohost",
Port: HTTPSPort, // since this is the 'default' HTTPS port, should not be included in Location value
},
{
Host: "*.example.com",
RequestHost: "foo.example.com",
},
{
Host: "*.example.com",
Port: "1234",
RequestHost: "foo.example.com:1234",
},
} {
cfg := redirPlaintextHost(&SiteConfig{
Addr: Address{
Host: testcase.Host,
Port: testcase.Port,
},
ListenHost: testcase.ListenHost,
TLS: new(caddytls.Config),
})
// Check host and port
if actual, expected := cfg.Addr.Host, testcase.Host; actual != expected {
t.Errorf("Test %d: Expected redir config to have host %s but got %s", i, expected, actual)
}
if actual, expected := cfg.ListenHost, testcase.ListenHost; actual != expected {
t.Errorf("Test %d: Expected redir config to have bindhost %s but got %s", i, expected, actual)
}
if actual, expected := cfg.Addr.Port, HTTPPort; actual != expected {
t.Errorf("Test %d: Expected redir config to have port '%s' but got '%s'", i, expected, actual)
}
// Make sure redirect handler is set up properly
if cfg.middleware == nil || len(cfg.middleware) != 1 {
t.Fatalf("Test %d: Redir config middleware not set up properly; got: %#v", i, cfg.middleware)
}
handler := cfg.middleware[0](nil)
// Check redirect for correctness, first by inspecting error and status code
requestHost := testcase.Host // hostname of request might be different than in config (e.g. wildcards)
if testcase.RequestHost != "" {
requestHost = testcase.RequestHost
}
rec := httptest.NewRecorder()
req, err := http.NewRequest("GET", "http://"+requestHost+"/bar?q=1", nil)
if err != nil {
t.Fatalf("Test %d: %v", i, err)
}
status, err := handler.ServeHTTP(rec, req)
if status != 0 {
t.Errorf("Test %d: Expected status return to be 0, but was %d", i, status)
}
if err != nil {
t.Errorf("Test %d: Expected returned error to be nil, but was %v", i, err)
}
if rec.Code != http.StatusMovedPermanently {
t.Errorf("Test %d: Expected status %d but got %d", http.StatusMovedPermanently, i, rec.Code)
}
// Now check the Location value. It should mirror the hostname and port of the request
// unless the port is redundant, in which case it should be dropped.
locationHost, _, err := net.SplitHostPort(requestHost)
if err != nil {
locationHost = requestHost
}
expectedLoc := fmt.Sprintf("https://%s/bar?q=1", locationHost)
if testcase.Port != "" && testcase.Port != DefaultHTTPSPort {
expectedLoc = fmt.Sprintf("https://%s:%s/bar?q=1", locationHost, testcase.Port)
}
if got, want := rec.Header().Get("Location"), expectedLoc; got != want {
t.Errorf("Test %d: Expected Location: '%s' but got '%s'", i, want, got)
}
}
}
func TestHostHasOtherPort(t *testing.T) {
configs := []*SiteConfig{
{Addr: Address{Host: "example.com", Port: "80"}},
{Addr: Address{Host: "sub1.example.com", Port: "80"}},
{Addr: Address{Host: "sub1.example.com", Port: "443"}},
}
if hostHasOtherPort(configs, 0, "80") {
t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`)
}
if hostHasOtherPort(configs, 0, "443") {
t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`)
}
if !hostHasOtherPort(configs, 1, "443") {
t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`)
}
}
func TestMakePlaintextRedirects(t *testing.T) {
configs := []*SiteConfig{
// Happy path = standard redirect from 80 to 443
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Managed: true}},
// Host on port 80 already defined; don't change it (no redirect)
{Addr: Address{Host: "sub1.example.com", Port: "80", Scheme: "http"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "sub1.example.com"}, TLS: &caddytls.Config{Managed: true}},
// Redirect from port 80 to port 5000 in this case
{Addr: Address{Host: "sub2.example.com", Port: "5000"}, TLS: &caddytls.Config{Managed: true}},
// Can redirect from 80 to either 443 or 5001, but choose 443
{Addr: Address{Host: "sub3.example.com", Port: "443"}, TLS: &caddytls.Config{Managed: true}},
{Addr: Address{Host: "sub3.example.com", Port: "5001", Scheme: "https"}, TLS: &caddytls.Config{Managed: true}},
}
result := makePlaintextRedirects(configs)
expectedRedirCount := 3
if len(result) != len(configs)+expectedRedirCount {
t.Errorf("Expected %d redirect(s) to be added, but got %d",
expectedRedirCount, len(result)-len(configs))
}
}
func TestEnableAutoHTTPS(t *testing.T) {
configs := []*SiteConfig{
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Managed: true}},
{}, // not managed - no changes!
}
enableAutoHTTPS(configs, false)
if !configs[0].TLS.Enabled {
t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false")
}
if configs[0].Addr.Scheme != "https" {
t.Errorf("Expected config 0 to have Addr.Scheme == \"https\", but it was \"%s\"",
configs[0].Addr.Scheme)
}
if configs[1].TLS != nil && configs[1].TLS.Enabled {
t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true")
}
}
func TestMarkQualifiedForAutoHTTPS(t *testing.T) {
// TODO: caddytls.TestQualifiesForManagedTLS and this test share nearly the same config list...
configs := []*SiteConfig{
{Addr: Address{Host: ""}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "localhost"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "123.44.3.21"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "example.com"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Manual: true}},
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{ACMEEmail: "off"}},
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{ACMEEmail: "foo@bar.com"}},
{Addr: Address{Host: "example.com", Scheme: "http"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "example.com", Port: "80"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "example.com", Port: "1234"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "example.com", Scheme: "https"}, TLS: new(caddytls.Config)},
{Addr: Address{Host: "example.com", Port: "80", Scheme: "https"}, TLS: new(caddytls.Config)},
}
expectedManagedCount := 4
markQualifiedForAutoHTTPS(configs)
count := 0
for _, cfg := range configs {
if cfg.TLS.Managed {
count++
}
}
if count != expectedManagedCount {
t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
}
}

View File

@@ -0,0 +1,196 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"bytes"
"io"
"log"
"net"
"os"
"strings"
"sync"
"github.com/hashicorp/go-syslog"
"github.com/mholt/caddy"
)
var remoteSyslogPrefixes = map[string]string{
"syslog+tcp://": "tcp",
"syslog+udp://": "udp",
"syslog://": "udp",
}
// Logger is shared between errors and log plugins and supports both logging to
// a file (with an optional file roller), local and remote syslog servers.
type Logger struct {
Output string
*log.Logger
Roller *LogRoller
writer io.Writer
fileMu *sync.RWMutex
V4ipMask net.IPMask
V6ipMask net.IPMask
IPMaskExists bool
Exceptions []string
}
// NewTestLogger creates logger suitable for testing purposes
func NewTestLogger(buffer *bytes.Buffer) *Logger {
return &Logger{
Logger: log.New(buffer, "", 0),
fileMu: new(sync.RWMutex),
}
}
// Println wraps underlying logger with mutex
func (l Logger) Println(args ...interface{}) {
l.fileMu.RLock()
l.Logger.Println(args...)
l.fileMu.RUnlock()
}
// Printf wraps underlying logger with mutex
func (l Logger) Printf(format string, args ...interface{}) {
l.fileMu.RLock()
l.Logger.Printf(format, args...)
l.fileMu.RUnlock()
}
func (l Logger) MaskIP(ip string) string {
var reqIP net.IP
// If unable to parse, simply return IP as provided.
reqIP = net.ParseIP(ip)
if reqIP == nil {
return ip
}
if reqIP.To4() != nil {
return reqIP.Mask(l.V4ipMask).String()
} else {
return reqIP.Mask(l.V6ipMask).String()
}
}
// ShouldLog returns true if the path is not exempted from
// being logged (i.e. it is not found in l.Exceptions).
func (l Logger) ShouldLog(path string) bool {
for _, exc := range l.Exceptions {
if Path(path).Matches(exc) {
return false
}
}
return true
}
// Attach binds logger Start and Close functions to
// controller's OnStartup and OnShutdown hooks.
func (l *Logger) Attach(controller *caddy.Controller) {
if controller != nil {
// Opens file or connect to local/remote syslog
controller.OnStartup(l.Start)
// Closes file or disconnects from local/remote syslog
controller.OnShutdown(l.Close)
}
}
type syslogAddress struct {
network string
address string
}
func parseSyslogAddress(location string) *syslogAddress {
for prefix, network := range remoteSyslogPrefixes {
if strings.HasPrefix(location, prefix) {
return &syslogAddress{
network: network,
address: strings.TrimPrefix(location, prefix),
}
}
}
return nil
}
// Start initializes logger opening files or local/remote syslog connections
func (l *Logger) Start() error {
// initialize mutex on start
l.fileMu = new(sync.RWMutex)
var err error
selectwriter:
switch l.Output {
case "", "stderr":
l.writer = os.Stderr
case "stdout":
l.writer = os.Stdout
case "syslog":
l.writer, err = gsyslog.NewLogger(gsyslog.LOG_ERR, "LOCAL0", "caddy")
if err != nil {
return err
}
default:
if address := parseSyslogAddress(l.Output); address != nil {
l.writer, err = gsyslog.DialLogger(address.network, address.address, gsyslog.LOG_ERR, "LOCAL0", "caddy")
if err != nil {
return err
}
break selectwriter
}
var file *os.File
file, err = os.OpenFile(l.Output, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
return err
}
if l.Roller != nil {
file.Close()
l.Roller.Filename = l.Output
l.writer = l.Roller.GetLogWriter()
} else {
l.writer = file
}
}
l.Logger = log.New(l.writer, "", 0)
return nil
}
// Close closes open log files or connections to syslog.
func (l *Logger) Close() error {
// don't close stdout or stderr
if l.writer == os.Stdout || l.writer == os.Stderr {
return nil
}
// Will close local/remote syslog connections too :)
if closer, ok := l.writer.(io.WriteCloser); ok {
l.fileMu.Lock()
err := closer.Close()
l.fileMu.Unlock()
return err
}
return nil
}

View File

@@ -0,0 +1,226 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//+build linux darwin
package httpserver
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"sync"
"testing"
syslog "gopkg.in/mcuadros/go-syslog.v2"
"gopkg.in/mcuadros/go-syslog.v2/format"
)
func TestLoggingToStdout(t *testing.T) {
testCases := []struct {
Output string
ExpectedOutput string
}{
{
Output: "stdout",
ExpectedOutput: "Hello world logged to stdout",
},
}
for i, testCase := range testCases {
output := captureStdout(func() {
logger := Logger{Output: testCase.Output, fileMu: new(sync.RWMutex)}
if err := logger.Start(); err != nil {
t.Fatalf("Got unexpected error: %v", err)
}
logger.Println(testCase.ExpectedOutput)
})
if !strings.Contains(output, testCase.ExpectedOutput) {
t.Fatalf("Test #%d: Expected output to contain: %s, got: %s", i, testCase.ExpectedOutput, output)
}
}
}
func TestLoggingToStderr(t *testing.T) {
testCases := []struct {
Output string
ExpectedOutput string
}{
{
Output: "stderr",
ExpectedOutput: "Hello world logged to stderr",
},
{
Output: "",
ExpectedOutput: "Hello world logged to stderr #2",
},
}
for i, testCase := range testCases {
output := captureStderr(func() {
logger := Logger{Output: testCase.Output, fileMu: new(sync.RWMutex)}
if err := logger.Start(); err != nil {
t.Fatalf("Got unexpected error: %v", err)
}
logger.Println(testCase.ExpectedOutput)
})
if !strings.Contains(output, testCase.ExpectedOutput) {
t.Fatalf("Test #%d: Expected output to contain: %s, got: %s", i, testCase.ExpectedOutput, output)
}
}
}
func TestLoggingToFile(t *testing.T) {
file := filepath.Join(os.TempDir(), "access.log")
expectedOutput := "Hello world written to file"
logger := Logger{Output: file}
if err := logger.Start(); err != nil {
t.Fatalf("Got unexpected error during logger start: %v", err)
}
logger.Print(expectedOutput)
content, err := ioutil.ReadFile(file)
if err != nil {
t.Fatalf("Could not read log file content: %v", err)
}
if !bytes.Contains(content, []byte(expectedOutput)) {
t.Fatalf("Expected log file to contain: %s, got: %s", expectedOutput, string(content))
}
os.Remove(file)
}
func TestLoggingToSyslog(t *testing.T) {
testCases := []struct {
Output string
ExpectedOutput string
}{
{
Output: "syslog://127.0.0.1:5660",
ExpectedOutput: "Hello world! Test #1 over tcp",
},
{
Output: "syslog+tcp://127.0.0.1:5661",
ExpectedOutput: "Hello world! Test #2 over tcp",
},
{
Output: "syslog+udp://127.0.0.1:5662",
ExpectedOutput: "Hello world! Test #3 over udp",
},
}
for i, testCase := range testCases {
ch := make(chan format.LogParts, 256)
server, err := bootServer(testCase.Output, ch)
defer server.Kill()
if err != nil {
t.Errorf("Test #%d: expected no error during syslog server boot, got: %v", i, err)
}
logger := Logger{Output: testCase.Output, fileMu: new(sync.RWMutex)}
if err := logger.Start(); err != nil {
t.Errorf("Test #%d: expected no error during logger start, got: %v", i, err)
}
defer logger.Close()
logger.Print(testCase.ExpectedOutput)
actual := <-ch
if content, ok := actual["content"].(string); ok {
if !strings.Contains(content, testCase.ExpectedOutput) {
t.Errorf("Test #%d: expected server to capture content: %s, but got: %s", i, testCase.ExpectedOutput, content)
}
} else {
t.Errorf("Test #%d: expected server to capture content but got: %v", i, actual)
}
}
}
func bootServer(location string, ch chan format.LogParts) (*syslog.Server, error) {
address := parseSyslogAddress(location)
if address == nil {
return nil, fmt.Errorf("Could not parse syslog address: %s", location)
}
server := syslog.NewServer()
server.SetFormat(syslog.Automatic)
switch address.network {
case "tcp":
server.ListenTCP(address.address)
case "udp":
server.ListenUDP(address.address)
}
server.SetHandler(syslog.NewChannelHandler(ch))
if err := server.Boot(); err != nil {
return nil, err
}
return server, nil
}
func captureStdout(f func()) string {
original := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
f()
w.Close()
written, _ := ioutil.ReadAll(r)
os.Stdout = original
return string(written)
}
func captureStderr(f func()) string {
original := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
f()
w.Close()
written, _ := ioutil.ReadAll(r)
os.Stderr = original
return string(written)
}

View File

@@ -0,0 +1,228 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"fmt"
"net/http"
"os"
"path"
"time"
"github.com/mholt/caddy"
)
func init() {
initCaseSettings()
}
type (
// Middleware is the middle layer which represents the traditional
// idea of middleware: it chains one Handler to the next by being
// passed the next Handler in the chain.
Middleware func(Handler) Handler
// ListenerMiddleware is similar to the Middleware type, except it
// chains one net.Listener to the next.
ListenerMiddleware func(caddy.Listener) caddy.Listener
// Handler is like http.Handler except ServeHTTP may return a status
// code and/or error.
//
// If ServeHTTP writes the response header, it should return a status
// code of 0. This signals to other handlers before it that the response
// is already handled, and that they should not write to it also. Keep
// in mind that writing to the response body writes the header, too.
//
// If ServeHTTP encounters an error, it should return the error value
// so it can be logged by designated error-handling middleware.
//
// If writing a response after calling the next ServeHTTP method, the
// returned status code SHOULD be used when writing the response.
//
// If handling errors after calling the next ServeHTTP method, the
// returned error value SHOULD be logged or handled accordingly.
//
// Otherwise, return values should be propagated down the middleware
// chain by returning them unchanged.
Handler interface {
ServeHTTP(http.ResponseWriter, *http.Request) (int, error)
}
// HandlerFunc is a convenience type like http.HandlerFunc, except
// ServeHTTP returns a status code and an error. See Handler
// documentation for more information.
HandlerFunc func(http.ResponseWriter, *http.Request) (int, error)
// RequestMatcher checks to see if current request should be handled
// by underlying handler.
RequestMatcher interface {
Match(r *http.Request) bool
}
// HandlerConfig is a middleware configuration.
// This makes it possible for middlewares to have a common
// configuration interface.
//
// TODO The long term plan is to get all middleware implement this
// interface for configurations.
HandlerConfig interface {
RequestMatcher
BasePath() string
}
// ConfigSelector selects a configuration.
ConfigSelector []HandlerConfig
)
// ServeHTTP implements the Handler interface.
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
return f(w, r)
}
// Select selects a Config.
// This chooses the config with the longest length.
func (c ConfigSelector) Select(r *http.Request) (config HandlerConfig) {
for i := range c {
if !c[i].Match(r) {
continue
}
if config == nil || len(c[i].BasePath()) > len(config.BasePath()) {
config = c[i]
}
}
return config
}
// IndexFile looks for a file in /root/fpath/indexFile for each string
// in indexFiles. If an index file is found, it returns the root-relative
// path to the file and true. If no index file is found, empty string
// and false is returned. fpath must end in a forward slash '/'
// otherwise no index files will be tried (directory paths must end
// in a forward slash according to HTTP).
//
// All paths passed into and returned from this function use '/' as the
// path separator, just like URLs. IndexFle handles path manipulation
// internally for systems that use different path separators.
func IndexFile(root http.FileSystem, fpath string, indexFiles []string) (string, bool) {
if fpath[len(fpath)-1] != '/' || root == nil {
return "", false
}
for _, indexFile := range indexFiles {
// func (http.FileSystem).Open wants all paths separated by "/",
// regardless of operating system convention, so use
// path.Join instead of filepath.Join
fp := path.Join(fpath, indexFile)
f, err := root.Open(fp)
if err == nil {
f.Close()
return fp, true
}
}
return "", false
}
// SetLastModifiedHeader checks if the provided modTime is valid and if it is sets it
// as a Last-Modified header to the ResponseWriter. If the modTime is in the future
// the current time is used instead.
func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) {
if modTime.IsZero() || modTime.Equal(time.Unix(0, 0)) {
// the time does not appear to be valid. Don't put it in the response
return
}
// RFC 2616 - Section 14.29 - Last-Modified:
// An origin server MUST NOT send a Last-Modified date which is later than the
// server's time of message origination. In such cases, where the resource's last
// modification would indicate some time in the future, the server MUST replace
// that date with the message origination date.
now := currentTime()
if modTime.After(now) {
modTime = now
}
w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat))
}
// CaseSensitivePath determines if paths should be case sensitive.
// This is configurable via CASE_SENSITIVE_PATH environment variable.
var CaseSensitivePath = false
const caseSensitivePathEnv = "CASE_SENSITIVE_PATH"
// initCaseSettings loads case sensitivity config from environment variable.
//
// This could have been in init, but init cannot be called from tests.
func initCaseSettings() {
switch os.Getenv(caseSensitivePathEnv) {
case "1", "true":
CaseSensitivePath = true
default:
CaseSensitivePath = false
}
}
// MergeRequestMatchers merges multiple RequestMatchers into one.
// This allows a middleware to use multiple RequestMatchers.
func MergeRequestMatchers(matchers ...RequestMatcher) RequestMatcher {
return requestMatchers(matchers)
}
type requestMatchers []RequestMatcher
// Match satisfies RequestMatcher interface.
func (m requestMatchers) Match(r *http.Request) bool {
for _, matcher := range m {
if !matcher.Match(r) {
return false
}
}
return true
}
// currentTime, as it is defined here, returns time.Now().
// It's defined as a variable for mocking time in tests.
var currentTime = func() time.Time { return time.Now() }
// EmptyNext is a no-op function that can be passed into
// Middleware functions so that the assignment to the
// Next field of the Handler can be tested.
//
// Used primarily for testing but needs to be exported so
// plugins can use this as a convenience.
var EmptyNext = HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return 0, nil })
// SameNext does a pointer comparison between next1 and next2.
//
// Used primarily for testing but needs to be exported so
// plugins can use this as a convenience.
func SameNext(next1, next2 Handler) bool {
return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2)
}
// Context key constants.
const (
// ReplacerCtxKey is the context key for a per-request replacer.
ReplacerCtxKey caddy.CtxKey = "replacer"
// RemoteUserCtxKey is the key for the remote user of the request, if any (basicauth).
RemoteUserCtxKey caddy.CtxKey = "remote_user"
// MitmCtxKey is the key for the result of MITM detection
MitmCtxKey caddy.CtxKey = "mitm"
// RequestIDCtxKey is the key for the U4 UUID value
RequestIDCtxKey caddy.CtxKey = "request_id"
)

View File

@@ -0,0 +1,72 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"os"
"testing"
)
func TestPathCaseSensitivity(t *testing.T) {
tests := []struct {
basePath string
path string
caseSensitive bool
expected bool
}{
{"/", "/file", true, true},
{"/a", "/file", true, false},
{"/f", "/file", true, true},
{"/f", "/File", true, false},
{"/f", "/File", false, true},
{"/file", "/file", true, true},
{"/file", "/file", false, true},
{"/files", "/file", false, false},
{"/files", "/file", true, false},
{"/folder", "/folder/file.txt", true, true},
{"/folders", "/folder/file.txt", true, false},
{"/folder", "/Folder/file.txt", false, true},
{"/folders", "/Folder/file.txt", false, false},
}
for i, test := range tests {
CaseSensitivePath = test.caseSensitive
valid := Path(test.path).Matches(test.basePath)
if test.expected != valid {
t.Errorf("Test %d: Expected %v, found %v", i, test.expected, valid)
}
}
}
func TestPathCaseSensitiveEnv(t *testing.T) {
tests := []struct {
envValue string
expected bool
}{
{"1", true},
{"0", false},
{"false", false},
{"true", true},
{"", false},
}
for i, test := range tests {
os.Setenv(caseSensitivePathEnv, test.envValue)
initCaseSettings()
if test.expected != CaseSensitivePath {
t.Errorf("Test %d: Expected %v, found %v", i, test.expected, CaseSensitivePath)
}
}
}

View File

@@ -0,0 +1,779 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"bytes"
"context"
"crypto/tls"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"github.com/mholt/caddy/caddytls"
"github.com/mholt/caddy/telemetry"
)
// tlsHandler is a http.Handler that will inject a value
// into the request context indicating if the TLS
// connection is likely being intercepted.
type tlsHandler struct {
next http.Handler
listener *tlsHelloListener
closeOnMITM bool // whether to close connection on MITM; TODO: expose through new directive
}
// ServeHTTP checks the User-Agent. For the four main browsers (Chrome,
// Edge, Firefox, and Safari) indicated by the User-Agent, the properties
// of the TLS Client Hello will be compared. The context value "mitm" will
// be set to a value indicating if it is likely that the underlying TLS
// connection is being intercepted.
//
// Note that due to Microsoft's decision to intentionally make IE/Edge
// user agents obscure (and look like other browsers), this may offer
// less accuracy for IE/Edge clients.
//
// This MITM detection capability is based on research done by Durumeric,
// Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17):
// https://jhalderm.com/pub/papers/interception-ndss17.pdf
func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// TODO: one request per connection, we should report UA in connection with
// handshake (reported in caddytls package) and our MITM assessment
if h.listener == nil {
h.next.ServeHTTP(w, r)
return
}
h.listener.helloInfosMu.RLock()
info := h.listener.helloInfos[r.RemoteAddr]
h.listener.helloInfosMu.RUnlock()
ua := r.Header.Get("User-Agent")
uaHash := telemetry.FastHash([]byte(ua))
// report this request's UA in connection with this ClientHello
go telemetry.AppendUnique("tls_client_hello_ua:"+caddytls.ClientHelloInfo(info).Key(), uaHash)
var checked, mitm bool
if r.Header.Get("X-BlueCoat-Via") != "" || // Blue Coat (masks User-Agent header to generic values)
r.Header.Get("X-FCCKV2") != "" || // Fortinet
info.advertisesHeartbeatSupport() { // no major browsers have ever implemented Heartbeat
checked = true
mitm = true
} else if strings.Contains(ua, "Edge") || strings.Contains(ua, "MSIE") ||
strings.Contains(ua, "Trident") {
checked = true
mitm = !info.looksLikeEdge()
} else if strings.Contains(ua, "Chrome") {
checked = true
mitm = !info.looksLikeChrome()
} else if strings.Contains(ua, "CriOS") {
// Chrome on iOS sometimes uses iOS-provided TLS stack (which looks exactly like Safari)
// but for connections that don't render a web page (favicon, etc.) it uses its own...
checked = true
mitm = !info.looksLikeChrome() && !info.looksLikeSafari()
} else if strings.Contains(ua, "Firefox") {
checked = true
if strings.Contains(ua, "Windows") {
ver := getVersion(ua, "Firefox")
if ver == 45.0 || ver == 52.0 {
mitm = !info.looksLikeTor()
} else {
mitm = !info.looksLikeFirefox()
}
} else {
mitm = !info.looksLikeFirefox()
}
} else if strings.Contains(ua, "Safari") {
checked = true
mitm = !info.looksLikeSafari()
}
if checked {
r = r.WithContext(context.WithValue(r.Context(), MitmCtxKey, mitm))
if mitm {
go telemetry.AppendUnique("http_mitm", "likely")
} else {
go telemetry.AppendUnique("http_mitm", "unlikely")
}
} else {
go telemetry.AppendUnique("http_mitm", "unknown")
}
if mitm && h.closeOnMITM {
// TODO: This termination might need to happen later in the middleware
// chain in order to be picked up by the log directive, in case the site
// owner still wants to log this event. It'll probably require a new
// directive. If this feature is useful, we can finish implementing this.
r.Close = true
return
}
h.next.ServeHTTP(w, r)
}
// getVersion returns a (possibly simplified) representation of the version string
// from a UserAgent string. It returns a float, so it can represent major and minor
// versions; the rest of the version is just tacked on behind the decimal point.
// The purpose of this is to stay simple while allowing for basic, fast comparisons.
// If the version for softwareName is not found in ua, -1 is returned.
func getVersion(ua, softwareName string) float64 {
search := softwareName + "/"
start := strings.Index(ua, search)
if start < 0 {
return -1
}
start += len(search)
end := strings.Index(ua[start:], " ")
if end < 0 {
end = len(ua)
} else {
end += start
}
strVer := strings.Replace(ua[start:end], "-", "", -1)
firstDot := strings.Index(strVer, ".")
if firstDot >= 0 {
strVer = strVer[:firstDot+1] + strings.Replace(strVer[firstDot+1:], ".", "", -1)
}
ver, err := strconv.ParseFloat(strVer, 64)
if err != nil {
return -1
}
return ver
}
// clientHelloConn reads the ClientHello
// and stores it in the attached listener.
type clientHelloConn struct {
net.Conn
listener *tlsHelloListener
readHello bool // whether ClientHello has been read
buf *bytes.Buffer
}
// Read reads from c.Conn (by letting the standard library
// do the reading off the wire), with the exception of
// getting a copy of the ClientHello so it can parse it.
func (c *clientHelloConn) Read(b []byte) (n int, err error) {
// if we've already read the ClientHello, pass thru
if c.readHello {
return c.Conn.Read(b)
}
// we let the standard lib read off the wire for us, and
// tee that into our buffer so we can read the ClientHello
tee := io.TeeReader(c.Conn, c.buf)
n, err = tee.Read(b)
if err != nil {
return
}
if c.buf.Len() < 5 {
return // need to read more bytes for header
}
// read the header bytes
hdr := make([]byte, 5)
_, err = io.ReadFull(c.buf, hdr)
if err != nil {
return // this would be highly unusual and sad
}
// get length of the ClientHello message and read it
length := int(uint16(hdr[3])<<8 | uint16(hdr[4]))
if c.buf.Len() < length {
return // need to read more bytes
}
hello := make([]byte, length)
_, err = io.ReadFull(c.buf, hello)
if err != nil {
return
}
bufpool.Put(c.buf) // buffer no longer needed
// parse the ClientHello and store it in the map
rawParsed := parseRawClientHello(hello)
c.listener.helloInfosMu.Lock()
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
c.listener.helloInfosMu.Unlock()
// report this ClientHello to telemetry
chKey := caddytls.ClientHelloInfo(rawParsed).Key()
go telemetry.SetNested("tls_client_hello", chKey, rawParsed)
go telemetry.AppendUnique("tls_client_hello_count", chKey)
c.readHello = true
return
}
// parseRawClientHello parses data which contains the raw
// TLS Client Hello message. It extracts relevant information
// into info. Any error reading the Client Hello (such as
// insufficient length or invalid length values) results in
// a silent error and an incomplete info struct, since there
// is no good way to handle an error like this during Accept().
// The data is expected to contain the whole ClientHello and
// ONLY the ClientHello.
//
// The majority of this code is borrowed from the Go standard
// library, which is (c) The Go Authors. It has been modified
// to fit this use case.
func parseRawClientHello(data []byte) (info rawHelloInfo) {
if len(data) < 42 {
return
}
info.Version = uint16(data[4])<<8 | uint16(data[5])
sessionIDLen := int(data[38])
if sessionIDLen > 32 || len(data) < 39+sessionIDLen {
return
}
data = data[39+sessionIDLen:]
if len(data) < 2 {
return
}
// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
// they are uint16s, the number must be even.
cipherSuiteLen := int(data[0])<<8 | int(data[1])
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
return
}
numCipherSuites := cipherSuiteLen / 2
// read in the cipher suites
info.CipherSuites = make([]uint16, numCipherSuites)
for i := 0; i < numCipherSuites; i++ {
info.CipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
}
data = data[2+cipherSuiteLen:]
if len(data) < 1 {
return
}
// read in the compression methods
compressionMethodsLen := int(data[0])
if len(data) < 1+compressionMethodsLen {
return
}
info.CompressionMethods = data[1 : 1+compressionMethodsLen]
data = data[1+compressionMethodsLen:]
// ClientHello is optionally followed by extension data
if len(data) < 2 {
return
}
extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:]
if extensionsLength != len(data) {
return
}
// read in each extension, and extract any relevant information
// from extensions we care about
for len(data) != 0 {
if len(data) < 4 {
return
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return
}
// record that the client advertised support for this extension
info.Extensions = append(info.Extensions, extension)
switch extension {
case extensionSupportedCurves:
// http://tools.ietf.org/html/rfc4492#section-5.5.1
if length < 2 {
return
}
l := int(data[0])<<8 | int(data[1])
if l%2 == 1 || length != l+2 {
return
}
numCurves := l / 2
info.Curves = make([]tls.CurveID, numCurves)
d := data[2:]
for i := 0; i < numCurves; i++ {
info.Curves[i] = tls.CurveID(d[0])<<8 | tls.CurveID(d[1])
d = d[2:]
}
case extensionSupportedPoints:
// http://tools.ietf.org/html/rfc4492#section-5.5.2
if length < 1 {
return
}
l := int(data[0])
if length != l+1 {
return
}
info.Points = make([]uint8, l)
copy(info.Points, data[1:])
}
data = data[length:]
}
return
}
// newTLSListener returns a new tlsHelloListener that wraps ln.
func newTLSListener(ln net.Listener, config *tls.Config) *tlsHelloListener {
return &tlsHelloListener{
Listener: ln,
config: config,
helloInfos: make(map[string]rawHelloInfo),
}
}
// tlsHelloListener is a TLS listener that is specially designed
// to read the ClientHello manually so we can extract necessary
// information from it. Each ClientHello message is mapped by
// the remote address of the client, which must be removed when
// the connection is closed (use ConnState).
type tlsHelloListener struct {
net.Listener
config *tls.Config
helloInfos map[string]rawHelloInfo
helloInfosMu sync.RWMutex
}
// Accept waits for and returns the next connection to the listener.
// After it accepts the underlying connection, it reads the
// ClientHello message and stores the parsed data into a map on l.
func (l *tlsHelloListener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
buf := bufpool.Get().(*bytes.Buffer)
buf.Reset()
helloConn := &clientHelloConn{Conn: conn, listener: l, buf: buf}
return tls.Server(helloConn, l.config), nil
}
// rawHelloInfo contains the "raw" data parsed from the TLS
// Client Hello. No interpretation is done on the raw data.
//
// The methods on this type implement heuristics described
// by Durumeric, Halderman, et. al. in
// "The Security Impact of HTTPS Interception":
// https://jhalderm.com/pub/papers/interception-ndss17.pdf
type rawHelloInfo caddytls.ClientHelloInfo
// advertisesHeartbeatSupport returns true if info indicates
// that the client supports the Heartbeat extension.
func (info rawHelloInfo) advertisesHeartbeatSupport() bool {
for _, ext := range info.Extensions {
if ext == extensionHeartbeat {
return true
}
}
return false
}
// looksLikeFirefox returns true if info looks like a handshake
// from a modern version of Firefox.
func (info rawHelloInfo) looksLikeFirefox() bool {
// "To determine whether a Firefox session has been
// intercepted, we check for the presence and order
// of extensions, cipher suites, elliptic curves,
// EC point formats, and handshake compression methods." (early 2016)
// We check for the presence and order of the extensions.
// Note: Sometimes 0x15 (21, padding) is present, sometimes not.
// Note: Firefox 51+ does not advertise 0x3374 (13172, NPN).
// Note: Firefox doesn't advertise 0x0 (0, SNI) when connecting to IP addresses.
// Note: Firefox 55+ doesn't appear to advertise 0xFF03 (65283, short headers). It used to be between 5 and 13.
// Note: Firefox on Fedora (or RedHat) doesn't include ECC suites because of patent liability.
requiredExtensionsOrder := []uint16{23, 65281, 10, 11, 35, 16, 5, 13}
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.Extensions, true) {
return false
}
// We check for both presence of curves and their ordering.
requiredCurves := []tls.CurveID{29, 23, 24, 25}
if len(info.Curves) < len(requiredCurves) {
return false
}
for i := range requiredCurves {
if info.Curves[i] != requiredCurves[i] {
return false
}
}
if len(info.Curves) > len(requiredCurves) {
// newer Firefox (55 Nightly?) may have additional curves at end of list
allowedCurves := []tls.CurveID{256, 257}
for i := range allowedCurves {
if info.Curves[len(requiredCurves)+i] != allowedCurves[i] {
return false
}
}
}
if hasGreaseCiphers(info.CipherSuites) {
return false
}
// We check for order of cipher suites but not presence, since
// according to the paper, cipher suites may be not be added
// or reordered by the user, but they may be disabled.
expectedCipherSuiteOrder := []uint16{
TLS_AES_128_GCM_SHA256, // 0x1301
TLS_CHACHA20_POLY1305_SHA256, // 0x1303
TLS_AES_256_GCM_SHA384, // 0x1302
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, // 0xc02b
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, // 0xc02f
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // 0xcca9
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // 0xcca8
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, // 0xc02c
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, // 0xc030
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, // 0xc00a
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, // 0xc009
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, // 0xc013
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, // 0xc014
TLS_DHE_RSA_WITH_AES_128_CBC_SHA, // 0x33
TLS_DHE_RSA_WITH_AES_256_CBC_SHA, // 0x39
tls.TLS_RSA_WITH_AES_128_CBC_SHA, // 0x2f
tls.TLS_RSA_WITH_AES_256_CBC_SHA, // 0x35
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, // 0xa
}
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.CipherSuites, false)
}
// looksLikeChrome returns true if info looks like a handshake
// from a modern version of Chrome.
func (info rawHelloInfo) looksLikeChrome() bool {
// "We check for ciphers and extensions that Chrome is known
// to not support, but do not check for the inclusion of
// specific ciphers or extensions, nor do we validate their
// order. When appropriate, we check the presence and order
// of elliptic curves, compression methods, and EC point formats." (early 2016)
// Not in Chrome 56, but present in Safari 10 (Feb. 2017):
// TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 (0xc024)
// TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 (0xc023)
// TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA (0xc00a)
// TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA (0xc009)
// TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 (0xc028)
// TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 (0xc027)
// TLS_RSA_WITH_AES_256_CBC_SHA256 (0x3d)
// TLS_RSA_WITH_AES_128_CBC_SHA256 (0x3c)
// Not in Chrome 56, but present in Firefox 51 (Feb. 2017):
// TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA (0xc00a)
// TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA (0xc009)
// TLS_DHE_RSA_WITH_AES_128_CBC_SHA (0x33)
// TLS_DHE_RSA_WITH_AES_256_CBC_SHA (0x39)
// Selected ciphers present in Chrome mobile (Feb. 2017):
// 0xc00a, 0xc014, 0xc009, 0x9c, 0x9d, 0x2f, 0x35, 0xa
chromeCipherExclusions := map[uint16]struct{}{
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: {}, // 0xc024
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: {}, // 0xc023
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384: {}, // 0xc028
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: {}, // 0xc027
TLS_RSA_WITH_AES_256_CBC_SHA256: {}, // 0x3d
tls.TLS_RSA_WITH_AES_128_CBC_SHA256: {}, // 0x3c
TLS_DHE_RSA_WITH_AES_128_CBC_SHA: {}, // 0x33
TLS_DHE_RSA_WITH_AES_256_CBC_SHA: {}, // 0x39
}
for _, ext := range info.CipherSuites {
if _, ok := chromeCipherExclusions[ext]; ok {
return false
}
}
// Chrome does not include curve 25 (CurveP521) (as of Chrome 56, Feb. 2017).
for _, curve := range info.Curves {
if curve == 25 {
return false
}
}
if !hasGreaseCiphers(info.CipherSuites) {
return false
}
return true
}
// looksLikeEdge returns true if info looks like a handshake
// from a modern version of MS Edge.
func (info rawHelloInfo) looksLikeEdge() bool {
// "SChannel connections can by uniquely identified because SChannel
// is the only TLS library we tested that includes the OCSP status
// request extension before the supported groups and EC point formats
// extensions." (early 2016)
//
// More specifically, the OCSP status request extension appears
// *directly* before the other two extensions, which occur in that
// order. (I contacted the authors for clarification and verified it.)
for i, ext := range info.Extensions {
if ext == extensionOCSPStatusRequest {
if len(info.Extensions) <= i+2 {
return false
}
if info.Extensions[i+1] != extensionSupportedCurves ||
info.Extensions[i+2] != extensionSupportedPoints {
return false
}
}
}
for _, cs := range info.CipherSuites {
// As of Feb. 2017, Edge does not have 0xff, but Avast adds it
if cs == scsvRenegotiation {
return false
}
// Edge and modern IE do not have 0x4 or 0x5, but Blue Coat does
if cs == TLS_RSA_WITH_RC4_128_MD5 || cs == tls.TLS_RSA_WITH_RC4_128_SHA {
return false
}
}
if hasGreaseCiphers(info.CipherSuites) {
return false
}
return true
}
// looksLikeSafari returns true if info looks like a handshake
// from a modern version of MS Safari.
func (info rawHelloInfo) looksLikeSafari() bool {
// "One unique aspect of Secure Transport is that it includes
// the TLS_EMPTY_RENEGOTIATION_INFO_SCSV (0xff) cipher first,
// whereas the other libraries we investigated include the
// cipher last. Similar to Microsoft, Apple has changed
// TLS behavior in minor OS updates, which are not indicated
// in the HTTP User-Agent header. We allow for any of the
// updates when validating handshakes, and we check for the
// presence and ordering of ciphers, extensions, elliptic
// curves, and compression methods." (early 2016)
// Note that any C lib (e.g. curl) compiled on macOS
// will probably use Secure Transport which will also
// share the TLS handshake characteristics of Safari.
// We check for the presence and order of the extensions.
requiredExtensionsOrder := []uint16{10, 11, 13, 13172, 16, 5, 18, 23}
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.Extensions, true) {
// Safari on iOS 11 (beta) uses different set/ordering of extensions
requiredExtensionsOrderiOS11 := []uint16{65281, 0, 23, 13, 5, 13172, 18, 16, 11, 10}
if !assertPresenceAndOrdering(requiredExtensionsOrderiOS11, info.Extensions, true) {
return false
}
} else {
// For these versions of Safari, expect TLS_EMPTY_RENEGOTIATION_INFO_SCSV first.
if len(info.CipherSuites) < 1 {
return false
}
if info.CipherSuites[0] != scsvRenegotiation {
return false
}
}
if hasGreaseCiphers(info.CipherSuites) {
return false
}
// We check for order and presence of cipher suites
expectedCipherSuiteOrder := []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, // 0xc02c
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, // 0xc02b
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, // 0xc024
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, // 0xc023
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, // 0xc00a
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, // 0xc009
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, // 0xc030
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, // 0xc02f
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, // 0xc028
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, // 0xc027
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, // 0xc014
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, // 0xc013
tls.TLS_RSA_WITH_AES_256_GCM_SHA384, // 0x9d
tls.TLS_RSA_WITH_AES_128_GCM_SHA256, // 0x9c
TLS_RSA_WITH_AES_256_CBC_SHA256, // 0x3d
tls.TLS_RSA_WITH_AES_128_CBC_SHA256, // 0x3c
tls.TLS_RSA_WITH_AES_256_CBC_SHA, // 0x35
tls.TLS_RSA_WITH_AES_128_CBC_SHA, // 0x2f
}
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.CipherSuites, true)
}
// looksLikeTor returns true if the info looks like a ClientHello from Tor browser
// (based on Firefox).
func (info rawHelloInfo) looksLikeTor() bool {
requiredExtensionsOrder := []uint16{10, 11, 16, 5, 13}
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.Extensions, true) {
return false
}
// check for session tickets support; Tor doesn't support them to prevent tracking
for _, ext := range info.Extensions {
if ext == 35 {
return false
}
}
// We check for both presence of curves and their ordering, including
// an optional curve at the beginning (for Tor based on Firefox 52)
infoCurves := info.Curves
if len(info.Curves) == 4 {
if info.Curves[0] != 29 {
return false
}
infoCurves = info.Curves[1:]
}
requiredCurves := []tls.CurveID{23, 24, 25}
if len(infoCurves) < len(requiredCurves) {
return false
}
for i := range requiredCurves {
if infoCurves[i] != requiredCurves[i] {
return false
}
}
if hasGreaseCiphers(info.CipherSuites) {
return false
}
// We check for order of cipher suites but not presence, since
// according to the paper, cipher suites may be not be added
// or reordered by the user, but they may be disabled.
expectedCipherSuiteOrder := []uint16{
TLS_AES_128_GCM_SHA256, // 0x1301
TLS_CHACHA20_POLY1305_SHA256, // 0x1303
TLS_AES_256_GCM_SHA384, // 0x1302
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, // 0xc02b
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, // 0xc02f
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // 0xcca9
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // 0xcca8
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, // 0xc02c
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, // 0xc030
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, // 0xc00a
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, // 0xc009
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, // 0xc013
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, // 0xc014
TLS_DHE_RSA_WITH_AES_128_CBC_SHA, // 0x33
TLS_DHE_RSA_WITH_AES_256_CBC_SHA, // 0x39
tls.TLS_RSA_WITH_AES_128_CBC_SHA, // 0x2f
tls.TLS_RSA_WITH_AES_256_CBC_SHA, // 0x35
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, // 0xa
}
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.CipherSuites, false)
}
// assertPresenceAndOrdering will return true if candidateList contains
// the items in requiredItems in the same order as requiredItems.
//
// If requiredIsSubset is true, then all items in requiredItems must be
// present in candidateList. If requiredIsSubset is false, then requiredItems
// may contain items that are not in candidateList.
//
// In all cases, the order of requiredItems is enforced.
func assertPresenceAndOrdering(requiredItems, candidateList []uint16, requiredIsSubset bool) bool {
superset := requiredItems
subset := candidateList
if requiredIsSubset {
superset = candidateList
subset = requiredItems
}
var j int
for _, item := range subset {
var found bool
for j < len(superset) {
if superset[j] == item {
found = true
break
}
j++
}
if j == len(superset) && !found {
return false
}
}
return true
}
func hasGreaseCiphers(cipherSuites []uint16) bool {
for _, cipher := range cipherSuites {
if _, ok := greaseCiphers[cipher]; ok {
return true
}
}
return false
}
// pool buffers so we can reuse allocations over time
var bufpool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
var greaseCiphers = map[uint16]struct{}{
0x0A0A: {},
0x1A1A: {},
0x2A2A: {},
0x3A3A: {},
0x4A4A: {},
0x5A5A: {},
0x6A6A: {},
0x7A7A: {},
0x8A8A: {},
0x9A9A: {},
0xAAAA: {},
0xBABA: {},
0xCACA: {},
0xDADA: {},
0xEAEA: {},
0xFAFA: {},
}
// Define variables used for TLS communication
const (
extensionOCSPStatusRequest = 5
extensionSupportedCurves = 10 // also called "SupportedGroups"
extensionSupportedPoints = 11
extensionHeartbeat = 15
scsvRenegotiation = 0xff
// cipher suites missing from the crypto/tls package,
// in no particular order here
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xc024
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xc028
TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x3d
TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x33
TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x39
TLS_RSA_WITH_RC4_128_MD5 = 0x4
// new PSK ciphers introduced by TLS 1.3, not (yet) in crypto/tls
// https://tlswg.github.io/tls13-spec/#rfc.appendix.A.4)
TLS_AES_128_GCM_SHA256 = 0x1301
TLS_AES_256_GCM_SHA384 = 0x1302
TLS_CHACHA20_POLY1305_SHA256 = 0x1303
TLS_AES_128_CCM_SHA256 = 0x1304
TLS_AES_128_CCM_8_SHA256 = 0x1305
)

View File

@@ -0,0 +1,424 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"crypto/tls"
"encoding/hex"
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
func TestParseClientHello(t *testing.T) {
for i, test := range []struct {
inputHex string
expected rawHelloInfo
}{
{
// curl 7.51.0 (x86_64-apple-darwin16.0) libcurl/7.51.0 SecureTransport zlib/1.2.8
inputHex: `010000a6030358a28c73a71bdfc1f09dee13fecdc58805dcce42ac44254df548f14645f7dc2c00004400ffc02cc02bc024c023c00ac009c008c030c02fc028c027c014c013c012009f009e006b0067003900330016009d009c003d003c0035002f000a00af00ae008d008c008b01000039000a00080006001700180019000b00020100000d00120010040102010501060104030203050306030005000501000000000012000000170000`,
expected: rawHelloInfo{
Version: 0x303,
CipherSuites: []uint16{255, 49196, 49195, 49188, 49187, 49162, 49161, 49160, 49200, 49199, 49192, 49191, 49172, 49171, 49170, 159, 158, 107, 103, 57, 51, 22, 157, 156, 61, 60, 53, 47, 10, 175, 174, 141, 140, 139},
Extensions: []uint16{10, 11, 13, 5, 18, 23},
CompressionMethods: []byte{0},
Curves: []tls.CurveID{23, 24, 25},
Points: []uint8{0},
},
},
{
// Chrome 56
inputHex: `010000c003031dae75222dae1433a5a283ddcde8ddabaefbf16d84f250eee6fdff48cdfff8a00000201a1ac02bc02fc02cc030cca9cca8cc14cc13c013c014009c009d002f0035000a010000777a7a0000ff010001000000000e000c0000096c6f63616c686f73740017000000230000000d00140012040308040401050308050501080606010201000500050100000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a000a0008aaaa001d001700182a2a000100`,
expected: rawHelloInfo{
Version: 0x303,
CipherSuites: []uint16{6682, 49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49171, 49172, 156, 157, 47, 53, 10},
Extensions: []uint16{31354, 65281, 0, 23, 35, 13, 5, 18, 16, 30032, 11, 10, 10794},
CompressionMethods: []byte{0},
Curves: []tls.CurveID{43690, 29, 23, 24},
Points: []uint8{0},
},
},
{
// Firefox 51
inputHex: `010000bd030375f9022fc3a6562467f3540d68013b2d0b961979de6129e944efe0b35531323500001ec02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a010000760000000e000c0000096c6f63616c686f737400170000ff01000100000a000a0008001d001700180019000b00020100002300000010000e000c02683208687474702f312e31000500050100000000ff030000000d0020001e040305030603020308040805080604010501060102010402050206020202`,
expected: rawHelloInfo{
Version: 0x303,
CipherSuites: []uint16{49195, 49199, 52393, 52392, 49196, 49200, 49162, 49161, 49171, 49172, 51, 57, 47, 53, 10},
Extensions: []uint16{0, 23, 65281, 10, 11, 35, 16, 5, 65283, 13},
CompressionMethods: []byte{0},
Curves: []tls.CurveID{29, 23, 24, 25},
Points: []uint8{0},
},
},
{
// openssl s_client (OpenSSL 0.9.8zh 14 Jan 2016)
inputHex: `0100012b03035d385236b8ca7b7946fa0336f164e76bf821ed90e8de26d97cc677671b6f36380000acc030c02cc028c024c014c00a00a500a300a1009f006b006a0069006800390038003700360088008700860085c032c02ec02ac026c00fc005009d003d00350084c02fc02bc027c023c013c00900a400a200a0009e00670040003f003e0033003200310030009a0099009800970045004400430042c031c02dc029c025c00ec004009c003c002f009600410007c011c007c00cc00200050004c012c008001600130010000dc00dc003000a00ff0201000055000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230000000d0020001e060106020603050105020503040104020403030103020303020102020203000f000101`,
expected: rawHelloInfo{
Version: 0x303,
CipherSuites: []uint16{49200, 49196, 49192, 49188, 49172, 49162, 165, 163, 161, 159, 107, 106, 105, 104, 57, 56, 55, 54, 136, 135, 134, 133, 49202, 49198, 49194, 49190, 49167, 49157, 157, 61, 53, 132, 49199, 49195, 49191, 49187, 49171, 49161, 164, 162, 160, 158, 103, 64, 63, 62, 51, 50, 49, 48, 154, 153, 152, 151, 69, 68, 67, 66, 49201, 49197, 49193, 49189, 49166, 49156, 156, 60, 47, 150, 65, 7, 49169, 49159, 49164, 49154, 5, 4, 49170, 49160, 22, 19, 16, 13, 49165, 49155, 10, 255},
Extensions: []uint16{11, 10, 35, 13, 15},
CompressionMethods: []byte{1, 0},
Curves: []tls.CurveID{23, 25, 28, 27, 24, 26, 22, 14, 13, 11, 12, 9, 10},
Points: []uint8{0, 1, 2},
},
},
} {
data, err := hex.DecodeString(test.inputHex)
if err != nil {
t.Fatalf("Test %d: Could not decode hex data: %v", i, err)
}
actual := parseRawClientHello(data)
if !reflect.DeepEqual(test.expected, actual) {
t.Errorf("Test %d: Expected %+v; got %+v", i, test.expected, actual)
}
}
}
func TestHeuristicFunctionsAndHandler(t *testing.T) {
// To test the heuristics, we assemble a collection of real
// ClientHello messages from various TLS clients, both genuine
// and intercepted. Please be sure to hex-encode them and
// document the User-Agent associated with the connection
// as well as any intercepting proxy as thoroughly as possible.
//
// If the TLS client used is not an HTTP client (e.g. s_client),
// you can leave the userAgent blank, but please use a comment
// to document crucial missing information such as client name,
// version, and platform, maybe even the date you collected
// the sample! Please group similar clients together, ordered
// by version for convenience.
// clientHello pairs a User-Agent string to its ClientHello message.
type clientHello struct {
userAgent string
helloHex string // do NOT include the header, just the ClientHello message
interception bool // if test case shows an interception, set to true
reqHeaders http.Header // if the request should set any headers to imitate a browser or proxy
}
// clientHellos groups samples of true (real) ClientHellos by the
// name of the browser that produced them. We limit the set of
// browsers to those we are programmed to protect, as well as a
// category for "Other" which contains real ClientHello messages
// from clients that we do not recognize, which may be used to
// test or imitate interception scenarios.
//
// Please group similar clients and order by version for convenience
// when adding to the test cases.
clientHellos := map[string][]clientHello{
"Chrome": {
{
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36",
helloHex: `010000c003031dae75222dae1433a5a283ddcde8ddabaefbf16d84f250eee6fdff48cdfff8a00000201a1ac02bc02fc02cc030cca9cca8cc14cc13c013c014009c009d002f0035000a010000777a7a0000ff010001000000000e000c0000096c6f63616c686f73740017000000230000000d00140012040308040401050308050501080606010201000500050100000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a000a0008aaaa001d001700182a2a000100`,
interception: false,
},
{
// Chrome on iOS will use iOS' TLS stack for requests that load
// the web page (apparently required by the dev ToS) but will use its
// own TLS stack for everything else, it seems.
// Chrome on iOS
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 10_0_2 like Mac OS X) AppleWebKit/602.1.50 (KHTML, like Gecko) CriOS/56.0.2924.79 Mobile/14A456 Safari/602.1",
helloHex: `010000de030358b062c509b21410a6496b5a82bfec74436cdecebe8ea1da29799939bbd3c17200002c00ffc02cc02bc024c023c00ac009c008c030c02fc028c027c014c013c012009d009c003d003c0035002f000a0100008900000014001200000f66696e6572706978656c732e636f6d000a00080006001700180019000b00020100000d00120010040102010501060104030203050306033374000000100030002e0268320568322d31360568322d31350568322d313408737064792f332e3106737064792f3308687474702f312e310005000501000000000012000000170000`,
},
{
// Chrome on iOS (requesting favicon)
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 10_0_2 like Mac OS X) AppleWebKit/602.1.50 (KHTML, like Gecko) CriOS/56.0.2924.79 Mobile/14A456 Safari/602.1",
helloHex: `010000c20303863eb64788e3b9638c261300318411cbdd8f09576d58eec1e744b6ce944f574f0000208a8acca9cca8cc14cc13c02bc02fc02cc030c013c014009c009d002f0035000a01000079baba0000ff0100010000000014001200000f66696e6572706978656c732e636f6d0017000000230000000d00140012040308040401050308050501080606010201000500050100000000001200000010000e000c02683208687474702f312e31000b00020100000a000a00083a3a001d001700184a4a000100`,
},
{
userAgent: "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36",
helloHex: `010000c603036f717a88212c3e9e41940f82c42acb3473e0e4a64e8f52d9af33d34e972e08a30000206a6ac02bc02fc02cc030cca9cca8cc14cc13c013c014009c009d002f0035000a0100007d7a7a0000ff0100010000000014001200000f66696e6572706978656c732e636f6d0017000000230000000d00140012040308040401050308050501080606010201000500050100000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a000a00087a7a001d001700188a8a000100`,
interception: false,
},
{
userAgent: "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36",
helloHex: `010001fc030383141d213d1bf069171843489faf808028d282c9828e1ba87637c863833c730720a67e76e152f4b704523b72317ef4587e231f02e2395e0ecac6be9f28c35e6ce600208a8ac02bc02fc02cc030cca9cca8cc14cc13c013c014009c009d002f0035000a010001931a1a0000ff0100010000000014001200000f66696e6572706978656c732e636f6d00170000002300785e85429bf1764f33111cd3ad5d1c56d765976fd962b49dbecbb6f7865e2a8d8536ad854f1fa99a8bbbf998814fee54a63a0bf162869d2bba37e9778304e7c4140825718e191b574c6246a0611de6447bdd80417f83ff9d9b7124069a9f74b90394ecb89bec5f6a1a67c1b89e50b8674782f53dd51807651a000d00140012040308040401050308050501080606010201000500050100000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a000a00081a1a001d001700182a2a0001000015009a00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000`,
interception: false,
},
{
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36",
helloHex: `010000c203034166c97e2016046e0c88ad867c410d0aee470f4d9b4ec8fe41a751d2a6348e3100001c4a4ac02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a0100007dcaca0000ff0100010000000014001200000f66696e6572706978656c732e636f6d0017000000230000000d00140012040308040401050308050501080606010201000500050100000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a000a00086a6a001d001700187a7a000100`,
interception: false,
},
{
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36",
helloHex: `010000c203037741795e73cd5b4949f79a0dc9cccc8b006e4c0ec324f965c6fe9f0833909f0100001c7a7ac02bc02fc02cc030cca9cca8c013c014009c009d002f0035000a0100007d7a7a0000ff0100010000000014001200000f66696e6572706978656c732e636f6d0017000000230000000d00140012040308040401050308050501080606010201000500050100000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a000a00084a4a001d001700185a5a000100`,
interception: false,
},
},
"Firefox": {
{
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.12; rv:51.0) Gecko/20100101 Firefox/51.0",
helloHex: `010000bd030375f9022fc3a6562467f3540d68013b2d0b961979de6129e944efe0b35531323500001ec02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a010000760000000e000c0000096c6f63616c686f737400170000ff01000100000a000a0008001d001700180019000b00020100002300000010000e000c02683208687474702f312e31000500050100000000ff030000000d0020001e040305030603020308040805080604010501060102010402050206020202`,
interception: false,
},
{
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.12; rv:53.0) Gecko/20100101 Firefox/53.0",
helloHex: `010001fc0303c99d54ae0628bbb9fea3833a4244c6a712cac9d7738f4930b8b9d8e2f6bd578220f7936cedb48907981c9292fb08ceee6f59bd6fddb3d4271ccd7c12380c5038ab001ec02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a01000195001500af000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000e000c0000096c6f63616c686f737400170000ff01000100000a000a0008001d001700180019000b000201000023007886da2d41843ff42131b856982c19a545837b70e604325423a817d925e9d95bd084737682cea6b804dfb7cbe336a3b27b8d520d57520c29cfe5f4f3d3236183b84b05c18f0ca30bf598111e390086fea00d9631f1f78527277eb7838b86e73c4e5d15b55d086b1a4a8aa29f12a55126c6274bcd499bbeb23a0010000e000c02683208687474702f312e31000500050100000000000d0018001604030503060308040805080604010501060102030201`,
interception: false,
},
{
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.12; rv:53.0) Gecko/20100101 Firefox/53.0",
helloHex: `010000b1030365d899820b999245d571c2f7d6b850f63ad931d3c68ceb9cf5a508421a871dc500001ec02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100006a0000000e000c0000096c6f63616c686f737400170000ff01000100000a000a0008001d001700180019000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d0018001604030503060308040805080604010501060102030201`,
interception: false,
},
{
// this was a Nightly release at the time
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.12; rv:55.0) Gecko/20100101 Firefox/55.0",
helloHex: `010001fc030331e380b7d12018e1202ef3327607203df5c5732b4fa5ab5abaf0b60034c2fb662070c836b9b89123e37f4f1074d152df438fa8ee8a0f89b036fd952f4fcc0b994f001c130113031302c02bc02fcca9cca8c02cc030c013c014002f0035000a0100019700000014001200000f63616464797365727665722e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b0002010000230078c97e7716a041e2ea824571bef26a3dff2bf50a883cd15d904ab2d17deb514f6e0a079ee7c212c000178387ffafc2e530b6df6662f570aae134330f13c458a0eaad5a96a9696f572110918740b15db1143d19aaaa706942030b433a7e6150f62b443c0564e5b8f7ee9577bf3bf7faec8c67425b648ab54d880010000e000c02683208687474702f312e310005000501000000000028006b0069001d0020aee6e596155ee6f79f943e81ceabe0979d27fbbb8b9189ccb2ebc75226351f32001700410421875a44e510decac11ef1d7cfddd4dfe105d5cd3a2d42fba03ebde23e51e8ce65bda1b48be82d4848d1db2bfce68e94092e925a9ce0dbf5df35479558108489002b0009087f12030303020301000d0018001604030503060308040805080604010501060102030201002d000201010015002500000000000000000000000000000000000000000000000000000000000000000000000000`,
interception: false,
},
{
// Firefox on Fedora (RedHat) doesn't include ECC ciphers because of patent liabilities
userAgent: "Mozilla/5.0 (X11; Fedora; Linux x86_64; rv:53.0) Gecko/20100101 Firefox/53.0",
helloHex: `010000b70303f5280b74d617d42e39fd77b78a2b537b1d7787ce4fcbcf3604c9fbcd677c6c5500001ec02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100007000000014001200000f66696e6572706978656c732e636f6d00170000ff01000100000a000a0008001d001700180019000b00020100002300000010000e000c02683208687474702f312e31000500050100000000000d0018001604030503060308040805080604010501060102030201`,
interception: false,
},
},
"Edge": {
{
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
helloHex: `010000bd030358a3c9bf05f734842e189fb6ce653b67b846e990bc1fc5fb8c397874d06020f1000038c02cc02bc030c02f009f009ec024c023c028c027c00ac009c014c01300390033009d009c003d003c0035002f000a006a00400038003200130100005c000500050100000000000a00080006001d00170018000b00020100000d00140012040105010201040305030203020206010603002300000010000e000c02683208687474702f312e310017000055000006000100020002ff01000100`,
interception: false,
},
},
"Safari": {
{
userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_3) AppleWebKit/602.4.8 (KHTML, like Gecko) Version/10.0.3 Safari/602.4.8",
helloHex: `010000d2030358a295b513c8140c6ff880f4a8a73cc830ed2dab2c4f2068eb365228d828732e00002600ffc02cc02bc024c023c00ac009c030c02fc028c027c014c013009d009c003d003c0035002f010000830000000e000c0000096c6f63616c686f7374000a00080006001700180019000b00020100000d00120010040102010501060104030203050306033374000000100030002e0268320568322d31360568322d31350568322d313408737064792f332e3106737064792f3308687474702f312e310005000501000000000012000000170000`,
interception: false,
},
{
// I think this was iOS 11 beta
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 11_0 like Mac OS X) AppleWebKit/604.1.28 (KHTML, like Gecko) Version/11.0 Mobile/15A5318g Safari/604.1",
helloHex: `010000e10303be294e11847ba01301e0bb6129f4a0d66344602141a8f0a1ab0750a1db145755000028c02cc02bc024c023cca9c00ac009c030c02fc028c027cca8c014c013009d009c003d003c0035002f01000090ff0100010000000014001200000f66696e6572706978656c732e636f6d00170000000d00140012040308040401050308050501080606010201000500050100000000337400000012000000100030002e0268320568322d31360568322d31350568322d313408737064792f332e3106737064792f3308687474702f312e31000b00020100000a00080006001d00170018`,
interception: false,
},
{
// iOS 11 stable
userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 11_0 like Mac OS X) AppleWebKit/604.1.38 (KHTML, like Gecko) Version/11.0 Mobile/15A372 Safari/604.1",
helloHex: `010000dc030327fafb16708fcbe489fda332260d32b1a22bea6672a72b5e61d7b9963df1b10d000028c02cc02bc024c023c00ac009cca9c030c02fc028c027c014c013cca8009d009c003d003c0035002f0100008bff010001000000000f000d00000a6d69746d2e776174636800170000000d00140012040308040401050308050501080606010201000500050100000000337400000012000000100030002e0268320568322d31360568322d31350568322d313408737064792f332e3106737064792f3308687474702f312e31000b00020100000a00080006001d00170018`,
interception: false,
},
},
"Tor": {
{
userAgent: "Mozilla/5.0 (Windows NT 6.1; rv:45.0) Gecko/20100101 Firefox/45.0",
helloHex: `010000a40303137f05d4151f2d9095aee4254416d9dce73d6a1d857e8097ea20d021c04a7a81000016c02bc02fc00ac009c013c01400330039002f0035000a0100006500000014001200000f66696e6572706978656c732e636f6dff01000100000a00080006001700180019000b00020100337400000010000b000908687474702f312e31000500050100000000000d001600140401050106010201040305030603020304020202`,
interception: false,
},
{
userAgent: "Mozilla/5.0 (Windows NT 6.1; rv:52.0) Gecko/20100101 Firefox/52.0",
helloHex: `010000b4030322e1f3aff4c37caba303c2ce53ba1689b3e70117a46f413d44f70a74cb6a496100001ec02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100006d00000014001200000f66696e6572706978656c732e636f6d00170000ff01000100000a000a0008001d001700180019000b000201000010000b000908687474702f312e31000500050100000000ff030000000d0018001604030503060308040805080604010501060102030201`,
interception: false,
},
},
"Other": { // these are either non-browser clients or intercepted client hellos
{
// openssl s_client (OpenSSL 0.9.8zh 14 Jan 2016) - NOT an interception, but not a browser either
helloHex: `0100012b03035d385236b8ca7b7946fa0336f164e76bf821ed90e8de26d97cc677671b6f36380000acc030c02cc028c024c014c00a00a500a300a1009f006b006a0069006800390038003700360088008700860085c032c02ec02ac026c00fc005009d003d00350084c02fc02bc027c023c013c00900a400a200a0009e00670040003f003e0033003200310030009a0099009800970045004400430042c031c02dc029c025c00ec004009c003c002f009600410007c011c007c00cc00200050004c012c008001600130010000dc00dc003000a00ff0201000055000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230000000d0020001e060106020603050105020503040104020403030103020303020102020203000f000101`,
// NOTE: This test case is not actually an interception, but s_client is not a browser
// or any client we support MITM checking for, either. Since it advertises heartbeat,
// our heuristics still flag it as a MITM.
interception: true,
},
{
// curl 7.51.0 (x86_64-apple-darwin16.0) libcurl/7.51.0 SecureTransport zlib/1.2.8
userAgent: "curl/7.51.0",
helloHex: `010000a6030358a28c73a71bdfc1f09dee13fecdc58805dcce42ac44254df548f14645f7dc2c00004400ffc02cc02bc024c023c00ac009c008c030c02fc028c027c014c013c012009f009e006b0067003900330016009d009c003d003c0035002f000a00af00ae008d008c008b01000039000a00080006001700180019000b00020100000d00120010040102010501060104030203050306030005000501000000000012000000170000`,
interception: false,
},
{
// Avast 17.1.2286 (Feb. 2017) on Windows 10 x64 build 14393, intercepting Edge
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
helloHex: `010000ce0303b418fdc4b6cf6436a5e2bfb06b96ed5faa7285c20c7b49341a78be962a9dc40000003ac02cc02bc030c02f009f009ec024c023c028c027c00ac009c014c01300390033009d009c003d003c0035002f000a006a004000380032001300ff0100006b00000014001200000f66696e6572706978656c732e636f6d000b000403000102000a00080006001d0017001800230000000d001400120401050102010403050302030202060106030005000501000000000010000e000c02683208687474702f312e310016000000170000`,
interception: true,
},
{
// Kaspersky Internet Security 17.0.0.611 on Windows 10 x64 build 14393, intercepting Edge
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
helloHex: `010000eb030361ce302bf4b0d5adf1ff30b2cf433c4a4b68f33e07b2651695e7ae6ec3cf126400003ac02cc02bc030c02f009f009ec024c023c028c027c00ac009c014c01300390033009d009c003d003c0035002f000a006a004000380032001300ff0100008800000014001200000f66696e6572706978656c732e636f6d000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230000000d0020001e060106020603050105020503040104020403030103020303020102020203000500050100000000000f0001010010000e000c02683208687474702f312e31`,
interception: true,
},
{
// Kaspersky Internet Security 17.0.0.611 on Windows 10 x64 build 14393, intercepting Firefox 51
userAgent: "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:51.0) Gecko/20100101 Firefox/51.0",
helloHex: `010001fc0303768e3f9ea75194c7cb03d23e8e6371b95fb696d339b797be57a634309ec98a42200f2a7554098364b7f05d21a8c7f43f31a893a4fc5670051020408c8e4dc234dd001cc02bc02fc02cc030c00ac009c013c01400330039002f0035000a00ff0100019700000014001200000f66696e6572706978656c732e636f6d000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230078bf4e244d4de3d53c6331edda9672dfc4a17aae92b671e86da1368b1b5ae5324372817d8f3b7ffe1a7a1537a5049b86cd7c44863978c1e615b005942755da20fc3a4e34a16f78034aa3b1cffcef95f81a0995c522a53b0e95a4f98db84c43359d93d8647b2de2a69f3ebdcfc6bca452730cbd00179226dedf000d0020001e060106020603050105020503040104020403030103020303020102020203000500050100000000000f0001010010000e000c02683208687474702f312e3100150093000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000`,
interception: true,
},
{
// Kaspersky Internet Security 17.0.0.611 on Windows 10 x64 build 14393, intercepting Chrome 56
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36",
helloHex: `010000c903033481e7af24e647ba5a79ec97e9264c1a1f990cf842f50effe22be52130d5af82000018c02bc02fc02cc030c013c014009c009d002f0035000a00ff0100008800000014001200000f66696e6572706978656c732e636f6d000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230000000d0020001e060106020603050105020503040104020403030103020303020102020203000500050100000000000f0001010010000e000c02683208687474702f312e31`,
interception: true,
},
{
// AVG 17.1.3006 (build 17.1.3354.20) on Windows 10 x64 build 14393, intercepting Edge
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
helloHex: `010000ca0303fd83091207161eca6b4887db50587109c50e463beb190362736b1fcf9e05f807000036c02cc02bc030c02f009f009ec024c023c028c027c00ac009c014c01300390033009d009c003d003c0035002f006a00400038003200ff0100006b00000014001200000f66696e6572706978656c732e636f6d000b000403000102000a00080006001d0017001800230000000d001400120401050102010403050302030202060106030005000501000000000010000e000c02683208687474702f312e310016000000170000`,
interception: true,
},
{
// IE 11 on Windows 7, this connection was intercepted by Blue Coat
// no sensible User-Agent value, since Blue Coat changes it to something super generic
// By the way, here's another reason we hate Blue Coat: they break TLS 1.3:
// https://twitter.com/FiloSottile/status/835269932929667072
helloHex: `010000b1030358a3f3bae627f464da8cb35976b88e9119640032d41e62a107d608ed8d3e62b9000034c028c027c014c013009f009e009d009cc02cc02bc024c023c00ac009003d003c0035002f006a004000380032000a0013000500040100005400000014001200000f66696e6572706978656c732e636f6d000500050100000000000a00080006001700180019000b00020100000d0014001206010603040105010201040305030203020200170000ff01000100`,
interception: true,
reqHeaders: http.Header{"X-Bluecoat-Via": {"66808702E9A2CF4"}}, // actual field name would be "X-BlueCoat-Via" but Go canonicalizes field names
},
{
// Firefox 51.0.1 being intercepted by burp 1.7.17
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:51.0) Gecko/20100101 Firefox/51.0",
helloHex: `010000d8030358a92f4daca95acc2f6a10a9c50d736135eae39406d3090238464540d482677600003ac023c027003cc025c02900670040c009c013002fc004c00e00330032c02bc02f009cc02dc031009e00a2c008c012000ac003c00d0016001300ff01000075000a0034003200170001000300130015000600070009000a0018000b000c0019000d000e000f001000110002001200040005001400080016000b00020100000d00180016060306010503050104030401040202030201020201010000001700150000126a61677561722e6b796877616e612e6f7267`,
interception: true,
},
{
// Chrome 56 on Windows 10 being intercepted by Fortigate (on some public school network); note: I had to enable TLS 1.0 for this test (proxy was issuing a SHA-1 cert to client)
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/56.0.2924.87 Safari/537.36",
helloHex: `010000e5030158ac612125c83bae95282113b2a4c572cf613c160d234350fb6d0ddce879ffec000064003300320039003800160013c013c009c014c00ac012c008002f0035000a00150012003d003c00670040006b006ac011c0070096009a009900410084004500440088008700ba00be00bd00c000c400c3c03cc044c042c03dc045c04300090005000400ff01000058000a003600340000000100020003000400050006000700080009000a000b000c000d000e000f0010001100120013001400150016001700180019000b0002010000000014001200000f66696e6572706978656c732e636f6d`,
interception: true,
},
{
// IE 11 on Windows 10, intercepted by Fortigate (same firewall as above)
userAgent: "Mozilla/5.0 (Windows NT 10.0; WOW64; Trident/7.0; rv:11.0) like Gecko",
helloHex: `010000e5030158ac634c5278d7b17421f23a64cc91d68c470c6b247322fe867ba035b373d05c000064003300320039003800160013c013c009c014c00ac012c008002f0035000a00150012003d003c00670040006b006ac011c0070096009a009900410084004500440088008700ba00be00bd00c000c400c3c03cc044c042c03dc045c04300090005000400ff01000058000a003600340000000100020003000400050006000700080009000a000b000c000d000e000f0010001100120013001400150016001700180019000b0002010000000014001200000f66696e6572706978656c732e636f6d`,
interception: true,
},
{
// Edge 38.14393.0.0 on Windows 10, intercepted by Fortigate (same as above)
userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
helloHex: `010000e5030158ac6421a45794b8ade6a0ac6c910cde0f99c49bb1ba737b88638ec8dcf0d077000064003300320039003800160013c013c009c014c00ac012c008002f0035000a00150012003d003c00670040006b006ac011c0070096009a009900410084004500440088008700ba00be00bd00c000c400c3c03cc044c042c03dc045c04300090005000400ff01000058000a003600340000000100020003000400050006000700080009000a000b000c000d000e000f0010001100120013001400150016001700180019000b0002010000000014001200000f66696e6572706978656c732e636f6d`,
interception: true,
},
{
// Firefox 50.0.1 on Windows 10, intercepted by Fortigate (same as above)
userAgent: "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:50.0) Gecko/20100101 Firefox/50.0",
helloHex: `010000e5030158ac64e40495e77b7baf2031281451620bfe354b0c37521ebc0a40f5dc0c0cb6000064003300320039003800160013c013c009c014c00ac012c008002f0035000a00150012003d003c00670040006b006ac011c0070096009a009900410084004500440088008700ba00be00bd00c000c400c3c03cc044c042c03dc045c04300090005000400ff01000058000a003600340000000100020003000400050006000700080009000a000b000c000d000e000f0010001100120013001400150016001700180019000b0002010000000014001200000f66696e6572706978656c732e636f6d`,
interception: true,
},
},
}
for client, chs := range clientHellos {
for i, ch := range chs {
hello, err := hex.DecodeString(ch.helloHex)
if err != nil {
t.Errorf("[%s] Test %d: Error decoding ClientHello: %v", client, i, err)
continue
}
parsed := parseRawClientHello(hello)
isChrome := parsed.looksLikeChrome()
isFirefox := parsed.looksLikeFirefox()
isSafari := parsed.looksLikeSafari()
isEdge := parsed.looksLikeEdge()
isTor := parsed.looksLikeTor()
// we want each of the heuristic functions to be as
// exclusive but as low-maintenance as possible;
// in other words, if one returns true, the others
// should return false, with as little logic as possible,
// but with enough logic to force TLS proxies to do a
// good job preserving characterstics of the handshake.
if (isChrome && (isFirefox || isSafari || isEdge || isTor)) ||
(isFirefox && (isChrome || isSafari || isEdge || isTor)) ||
(isSafari && (isChrome || isFirefox || isEdge || isTor)) ||
(isEdge && (isChrome || isFirefox || isSafari || isTor)) ||
(isTor && (isChrome || isFirefox || isSafari || isEdge)) {
t.Errorf("[%s] Test %d: Multiple fingerprinting functions matched: "+
"Chrome=%v Firefox=%v Safari=%v Edge=%v Tor=%v\n\tparsed hello dec: %+v\n",
client, i, isChrome, isFirefox, isSafari, isEdge, isTor, parsed)
}
// test the handler and detection results
var got, checked bool
want := ch.interception
handler := &tlsHandler{
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got, checked = r.Context().Value(MitmCtxKey).(bool)
}),
listener: newTLSListener(nil, nil),
}
handler.listener.helloInfos[""] = parsed
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
r.Header.Set("User-Agent", ch.userAgent)
if ch.reqHeaders != nil {
for field, values := range ch.reqHeaders {
r.Header[field] = values // NOTE: field names not standardized when setting directly like this!
}
}
handler.ServeHTTP(w, r)
if got != want {
t.Errorf("[%s] Test %d: Expected MITM=%v but got %v (type assertion OK (checked)=%v)",
client, i, want, got, checked)
t.Errorf("[%s] Test %d: Looks like Chrome=%v Firefox=%v Safari=%v Edge=%v Tor=%v\n\tparsed hello dec: %+v\n",
client, i, isChrome, isFirefox, isSafari, isEdge, isTor, parsed)
}
}
}
}
func TestGetVersion(t *testing.T) {
for i, test := range []struct {
UserAgent string
SoftwareName string
Version float64
}{
{
UserAgent: "Mozilla/5.0 (Windows NT 6.1; rv:45.0) Gecko/20100101 Firefox/45.0",
SoftwareName: "Firefox",
Version: 45.0,
},
{
UserAgent: "Mozilla/5.0 (Windows NT 6.1; rv:45.0) Gecko/20100101 Firefox/45.0 more_stuff_here",
SoftwareName: "Firefox",
Version: 45.0,
},
{
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
SoftwareName: "Safari",
Version: 537.36,
},
{
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
SoftwareName: "Chrome",
Version: 51.0270479,
},
{
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
SoftwareName: "Mozilla",
Version: 5.0,
},
{
UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.79 Safari/537.36 Edge/14.14393",
SoftwareName: "curl",
Version: -1,
},
} {
actual := getVersion(test.UserAgent, test.SoftwareName)
if actual != test.Version {
t.Errorf("Test [%d]: Expected version=%f, got version=%f for %s in '%s'",
i, test.Version, actual, test.SoftwareName, test.UserAgent)
}
}
}

View File

@@ -0,0 +1,67 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"net/http"
"path"
"strings"
)
// Path represents a URI path. It should usually be
// set to the value of a request path.
type Path string
// Matches checks to see if base matches p. The correct
// usage of this method sets p as the request path, and
// base as a Caddyfile (user-defined) rule path.
//
// Path matching will probably not always be a direct
// comparison; this method assures that paths can be
// easily and consistently matched.
//
// Multiple slashes are collapsed/merged. See issue #1859.
func (p Path) Matches(base string) bool {
if base == "/" || base == "" {
return true
}
// sanitize the paths for comparison, very important
// (slightly lossy if the base path requires multiple
// consecutive forward slashes, since those will be merged)
pHasTrailingSlash := strings.HasSuffix(string(p), "/")
baseHasTrailingSlash := strings.HasSuffix(base, "/")
p = Path(path.Clean(string(p)))
base = path.Clean(base)
if pHasTrailingSlash {
p += "/"
}
if baseHasTrailingSlash {
base += "/"
}
if CaseSensitivePath {
return strings.HasPrefix(string(p), base)
}
return strings.HasPrefix(strings.ToLower(string(p)), strings.ToLower(base))
}
// PathMatcher is a Path RequestMatcher.
type PathMatcher string
// Match satisfies RequestMatcher.
func (p PathMatcher) Match(r *http.Request) bool {
return Path(r.URL.Path).Matches(string(p))
}

View File

@@ -0,0 +1,146 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import "testing"
func TestPathMatches(t *testing.T) {
for i, testcase := range []struct {
reqPath Path
rulePath string // or "base path" as in Caddyfile docs
shouldMatch bool
caseInsensitive bool
}{
{
reqPath: "/",
rulePath: "/",
shouldMatch: true,
},
{
reqPath: "/foo/bar",
rulePath: "/foo",
shouldMatch: true,
},
{
reqPath: "/foobar",
rulePath: "/foo/",
shouldMatch: false,
},
{
reqPath: "/foobar",
rulePath: "/foo/bar",
shouldMatch: false,
},
{
reqPath: "/foo/",
rulePath: "/foo/",
shouldMatch: true,
},
{
reqPath: "/Foobar",
rulePath: "/Foo",
shouldMatch: true,
},
{
reqPath: "/FooBar",
rulePath: "/Foo",
shouldMatch: true,
},
{
reqPath: "/foobar",
rulePath: "/FooBar",
shouldMatch: true,
caseInsensitive: true,
},
{
reqPath: "",
rulePath: "/", // a lone forward slash means to match all requests (see issue #1645) - many future test cases related to this issue
shouldMatch: true,
},
{
reqPath: "foobar.php",
rulePath: "/",
shouldMatch: true,
},
{
reqPath: "",
rulePath: "",
shouldMatch: true,
},
{
reqPath: "/foo/bar",
rulePath: "",
shouldMatch: true,
},
{
reqPath: "/foo/bar",
rulePath: "",
shouldMatch: true,
},
{
reqPath: "no/leading/slash",
rulePath: "/",
shouldMatch: true,
},
{
reqPath: "no/leading/slash",
rulePath: "/no/leading/slash",
shouldMatch: false,
},
{
reqPath: "no/leading/slash",
rulePath: "",
shouldMatch: true,
},
{
// see issue #1859
reqPath: "//double-slash",
rulePath: "/double-slash",
shouldMatch: true,
},
{
reqPath: "/double//slash",
rulePath: "/double/slash",
shouldMatch: true,
},
{
reqPath: "//more/double//slashes",
rulePath: "/more/double/slashes",
shouldMatch: true,
},
{
reqPath: "/path/../traversal",
rulePath: "/traversal",
shouldMatch: true,
},
{
reqPath: "/path/../traversal",
rulePath: "/path",
shouldMatch: false,
},
{
reqPath: "/keep-slashes/http://something/foo/bar",
rulePath: "/keep-slashes/http://something",
shouldMatch: true,
},
} {
CaseSensitivePath = !testcase.caseInsensitive
if got, want := testcase.reqPath.Matches(testcase.rulePath), testcase.shouldMatch; got != want {
t.Errorf("Test %d: For request path '%s' and base path '%s': expected %v, got %v",
i, testcase.reqPath, testcase.rulePath, want, got)
}
}
}

View File

@@ -0,0 +1,701 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"crypto/tls"
"flag"
"fmt"
"log"
"net"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile"
"github.com/mholt/caddy/caddyhttp/staticfiles"
"github.com/mholt/caddy/caddytls"
"github.com/mholt/caddy/telemetry"
)
const serverType = "http"
func init() {
flag.StringVar(&HTTPPort, "http-port", HTTPPort, "Default port to use for HTTP")
flag.StringVar(&HTTPSPort, "https-port", HTTPSPort, "Default port to use for HTTPS")
flag.StringVar(&Host, "host", DefaultHost, "Default host")
flag.StringVar(&Port, "port", DefaultPort, "Default port")
flag.StringVar(&Root, "root", DefaultRoot, "Root path of default site")
flag.DurationVar(&GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown")
flag.BoolVar(&HTTP2, "http2", true, "Use HTTP/2")
flag.BoolVar(&QUIC, "quic", false, "Use experimental QUIC")
caddy.RegisterServerType(serverType, caddy.ServerType{
Directives: func() []string { return directives },
DefaultInput: func() caddy.Input {
if Port == DefaultPort && Host != "" {
// by leaving the port blank in this case we give auto HTTPS
// a chance to set the port to 443 for us
return caddy.CaddyfileInput{
Contents: []byte(fmt.Sprintf("%s\nroot %s", Host, Root)),
ServerTypeName: serverType,
}
}
return caddy.CaddyfileInput{
Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, Port, Root)),
ServerTypeName: serverType,
}
},
NewContext: newContext,
})
caddy.RegisterCaddyfileLoader("short", caddy.LoaderFunc(shortCaddyfileLoader))
caddy.RegisterParsingCallback(serverType, "root", hideCaddyfile)
caddy.RegisterParsingCallback(serverType, "tls", activateHTTPS)
caddytls.RegisterConfigGetter(serverType, func(c *caddy.Controller) *caddytls.Config { return GetConfig(c).TLS })
// disable the caddytls package reporting ClientHellos
// to telemetry, since our MITM detector does this but
// with more information than the standard lib provides
// (as of May 2018)
caddytls.ClientHelloTelemetry = false
}
// hideCaddyfile hides the source/origin Caddyfile if it is within the
// site root. This function should be run after parsing the root directive.
func hideCaddyfile(cctx caddy.Context) error {
ctx := cctx.(*httpContext)
for _, cfg := range ctx.siteConfigs {
// if no Caddyfile exists exit.
if cfg.originCaddyfile == "" {
return nil
}
absRoot, err := filepath.Abs(cfg.Root)
if err != nil {
return err
}
absOriginCaddyfile, err := filepath.Abs(cfg.originCaddyfile)
if err != nil {
return err
}
if strings.HasPrefix(absOriginCaddyfile, absRoot) {
cfg.HiddenFiles = append(cfg.HiddenFiles, filepath.ToSlash(strings.TrimPrefix(absOriginCaddyfile, absRoot)))
}
}
return nil
}
func newContext(inst *caddy.Instance) caddy.Context {
return &httpContext{instance: inst, keysToSiteConfigs: make(map[string]*SiteConfig)}
}
type httpContext struct {
instance *caddy.Instance
// keysToSiteConfigs maps an address at the top of a
// server block (a "key") to its SiteConfig. Not all
// SiteConfigs will be represented here, only ones
// that appeared in the Caddyfile.
keysToSiteConfigs map[string]*SiteConfig
// siteConfigs is the master list of all site configs.
siteConfigs []*SiteConfig
}
func (h *httpContext) saveConfig(key string, cfg *SiteConfig) {
h.siteConfigs = append(h.siteConfigs, cfg)
h.keysToSiteConfigs[key] = cfg
}
// InspectServerBlocks make sure that everything checks out before
// executing directives and otherwise prepares the directives to
// be parsed and executed.
func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
siteAddrs := make(map[string]string)
// For each address in each server block, make a new config
for _, sb := range serverBlocks {
for _, key := range sb.Keys {
addr, err := standardizeAddress(key)
if err != nil {
return serverBlocks, err
}
addr = addr.Normalize()
key = addr.Key()
if _, dup := h.keysToSiteConfigs[key]; dup {
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
}
// Fill in address components from command line so that middleware
// have access to the correct information during setup
if addr.Host == "" && Host != DefaultHost {
addr.Host = Host
}
if addr.Port == "" && Port != DefaultPort {
addr.Port = Port
}
// Make sure the adjusted site address is distinct
addrCopy := addr // make copy so we don't disturb the original, carefully-parsed address struct
if addrCopy.Port == "" && Port == DefaultPort {
addrCopy.Port = Port
}
addrStr := addrCopy.String()
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
err := fmt.Errorf("duplicate site address: %s", addrStr)
if (addrCopy.Host == Host && Host != DefaultHost) ||
(addrCopy.Port == Port && Port != DefaultPort) {
err = fmt.Errorf("site defined as %s is a duplicate of %s because of modified "+
"default host and/or port values (usually via -host or -port flags)", key, otherSiteKey)
}
return serverBlocks, err
}
siteAddrs[addrStr] = key
// If default HTTP or HTTPS ports have been customized,
// make sure the ACME challenge ports match
var altHTTPPort, altTLSSNIPort string
if HTTPPort != DefaultHTTPPort {
altHTTPPort = HTTPPort
}
if HTTPSPort != DefaultHTTPSPort {
altTLSSNIPort = HTTPSPort
}
// Make our caddytls.Config, which has a pointer to the
// instance's certificate cache and enough information
// to use automatic HTTPS when the time comes
caddytlsConfig := caddytls.NewConfig(h.instance)
caddytlsConfig.Hostname = addr.Host
caddytlsConfig.AltHTTPPort = altHTTPPort
caddytlsConfig.AltTLSSNIPort = altTLSSNIPort
// Save the config to our master list, and key it for lookups
cfg := &SiteConfig{
Addr: addr,
Root: Root,
TLS: caddytlsConfig,
originCaddyfile: sourceFile,
IndexPages: staticfiles.DefaultIndexPages,
}
h.saveConfig(key, cfg)
}
}
// For sites that have gzip (which gets chained in
// before the error handler) we should ensure that the
// errors directive also appears so error pages aren't
// written after the gzip writer is closed. See #616.
for _, sb := range serverBlocks {
_, hasGzip := sb.Tokens["gzip"]
_, hasErrors := sb.Tokens["errors"]
if hasGzip && !hasErrors {
sb.Tokens["errors"] = []caddyfile.Token{{Text: "errors"}}
}
}
return serverBlocks, nil
}
// MakeServers uses the newly-created siteConfigs to
// create and return a list of server instances.
func (h *httpContext) MakeServers() ([]caddy.Server, error) {
// make a rough estimate as to whether we're in a "production
// environment/system" - start by assuming that most production
// servers will set their default CA endpoint to a public,
// trusted CA (obviously not a perfect hueristic)
var looksLikeProductionCA bool
for _, publicCAEndpoint := range caddytls.KnownACMECAs {
if strings.Contains(caddytls.DefaultCAUrl, publicCAEndpoint) {
looksLikeProductionCA = true
break
}
}
// Iterate each site configuration and make sure that:
// 1) TLS is disabled for explicitly-HTTP sites (necessary
// when an HTTP address shares a block containing tls)
// 2) if QUIC is enabled, TLS ClientAuth is not, because
// currently, QUIC does not support ClientAuth (TODO:
// revisit this when our QUIC implementation supports it)
// 3) if TLS ClientAuth is used, StrictHostMatching is on
var atLeastOneSiteLooksLikeProduction bool
for _, cfg := range h.siteConfigs {
// see if all the addresses (both sites and
// listeners) are loopback to help us determine
// if this is a "production" instance or not
if !atLeastOneSiteLooksLikeProduction {
if !caddy.IsLoopback(cfg.Addr.Host) &&
!caddy.IsLoopback(cfg.ListenHost) &&
(caddytls.QualifiesForManagedTLS(cfg) ||
caddytls.HostQualifies(cfg.Addr.Host)) {
atLeastOneSiteLooksLikeProduction = true
}
}
// make sure TLS is disabled for explicitly-HTTP sites
// (necessary when HTTP address shares a block containing tls)
if !cfg.TLS.Enabled {
continue
}
if cfg.Addr.Port == HTTPPort || cfg.Addr.Scheme == "http" {
cfg.TLS.Enabled = false
log.Printf("[WARNING] TLS disabled for %s", cfg.Addr)
} else if cfg.Addr.Scheme == "" {
// set scheme to https ourselves, since TLS is enabled
// and it was not explicitly set to something else. this
// makes it appear as "https" when we print the list of
// running sites; otherwise "http" would be assumed which
// is incorrect for this site.
cfg.Addr.Scheme = "https"
}
if cfg.Addr.Port == "" && ((!cfg.TLS.Manual && !cfg.TLS.SelfSigned) || cfg.TLS.OnDemand) {
// this is vital, otherwise the function call below that
// sets the listener address will use the default port
// instead of 443 because it doesn't know about TLS.
cfg.Addr.Port = HTTPSPort
}
if cfg.TLS.ClientAuth != tls.NoClientCert {
if QUIC {
return nil, fmt.Errorf("cannot enable TLS client authentication with QUIC, because QUIC does not yet support it")
}
// this must be enabled so that a client cannot connect
// using SNI for another site on this listener that
// does NOT require ClientAuth, and then send HTTP
// requests with the Host header of this site which DOES
// require client auth, thus bypassing it...
cfg.StrictHostMatching = true
}
}
// we must map (group) each config to a bind address
groups, err := groupSiteConfigsByListenAddr(h.siteConfigs)
if err != nil {
return nil, err
}
// then we create a server for each group
var servers []caddy.Server
for addr, group := range groups {
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
}
// NOTE: This value is only a "good guess". Quite often, development
// environments will use internal DNS or a local hosts file to serve
// real-looking domains in local development. We can't easily tell
// which without doing a DNS lookup, so this guess is definitely naive,
// and if we ever want a better guess, we will have to do DNS lookups.
deploymentGuess := "dev"
if looksLikeProductionCA && atLeastOneSiteLooksLikeProduction {
deploymentGuess = "prod"
}
telemetry.Set("http_deployment_guess", deploymentGuess)
telemetry.Set("http_num_sites", len(h.siteConfigs))
return servers, nil
}
// normalizedKey returns "normalized" key representation:
// scheme and host names are lowered, everything else stays the same
func normalizedKey(key string) string {
addr, err := standardizeAddress(key)
if err != nil {
return key
}
return addr.Normalize().Key()
}
// GetConfig gets the SiteConfig that corresponds to c.
// If none exist (should only happen in tests), then a
// new, empty one will be created.
func GetConfig(c *caddy.Controller) *SiteConfig {
ctx := c.Context().(*httpContext)
key := normalizedKey(c.Key)
if cfg, ok := ctx.keysToSiteConfigs[key]; ok {
return cfg
}
// we should only get here during tests because directive
// actions typically skip the server blocks where we make
// the configs
cfg := &SiteConfig{Root: Root, TLS: new(caddytls.Config), IndexPages: staticfiles.DefaultIndexPages}
ctx.saveConfig(key, cfg)
return cfg
}
// shortCaddyfileLoader loads a Caddyfile if positional arguments are
// detected, or, in other words, if un-named arguments are provided to
// the program. A "short Caddyfile" is one in which each argument
// is a line of the Caddyfile. The default host and port are prepended
// according to the Host and Port values.
func shortCaddyfileLoader(serverType string) (caddy.Input, error) {
if flag.NArg() > 0 && serverType == "http" {
confBody := fmt.Sprintf("%s:%s\n%s", Host, Port, strings.Join(flag.Args(), "\n"))
return caddy.CaddyfileInput{
Contents: []byte(confBody),
Filepath: "args",
ServerTypeName: serverType,
}, nil
}
return nil, nil
}
// groupSiteConfigsByListenAddr groups site configs by their listen
// (bind) address, so sites that use the same listener can be served
// on the same server instance. The return value maps the listen
// address (what you pass into net.Listen) to the list of site configs.
// This function does NOT vet the configs to ensure they are compatible.
func groupSiteConfigsByListenAddr(configs []*SiteConfig) (map[string][]*SiteConfig, error) {
groups := make(map[string][]*SiteConfig)
for _, conf := range configs {
// We would add a special case here so that localhost addresses
// bind to 127.0.0.1 if conf.ListenHost is not already set, which
// would prevent outsiders from even connecting; but that was problematic:
// https://caddy.community/t/wildcard-virtual-domains-with-wildcard-roots/221/5?u=matt
if conf.Addr.Port == "" {
conf.Addr.Port = Port
}
addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.ListenHost, conf.Addr.Port))
if err != nil {
return nil, err
}
addrstr := addr.String()
groups[addrstr] = append(groups[addrstr], conf)
}
return groups, nil
}
// Address represents a site address. It contains
// the original input value, and the component
// parts of an address. The component parts may be
// updated to the correct values as setup proceeds,
// but the original value should never be changed.
type Address struct {
Original, Scheme, Host, Port, Path string
}
// String returns a human-friendly print of the address.
func (a Address) String() string {
if a.Host == "" && a.Port == "" {
return ""
}
scheme := a.Scheme
if scheme == "" {
if a.Port == HTTPSPort {
scheme = "https"
} else {
scheme = "http"
}
}
s := scheme
if s != "" {
s += "://"
}
s += a.Host
if a.Port != "" &&
((scheme == "https" && a.Port != DefaultHTTPSPort) ||
(scheme == "http" && a.Port != DefaultHTTPPort)) {
s += ":" + a.Port
}
if a.Path != "" {
s += a.Path
}
return s
}
// VHost returns a sensible concatenation of Host:Port/Path from a.
// It's basically the a.Original but without the scheme.
func (a Address) VHost() string {
if idx := strings.Index(a.Original, "://"); idx > -1 {
return a.Original[idx+3:]
}
return a.Original
}
// Normalize normalizes URL: turn scheme and host names into lower case
func (a Address) Normalize() Address {
path := a.Path
if !CaseSensitivePath {
path = strings.ToLower(path)
}
return Address{
Original: a.Original,
Scheme: strings.ToLower(a.Scheme),
Host: strings.ToLower(a.Host),
Port: a.Port,
Path: path,
}
}
// Key is similar to String, just replaces scheme and host values with modified values.
// Unlike String it doesn't add anything default (scheme, port, etc)
func (a Address) Key() string {
res := ""
if a.Scheme != "" {
res += a.Scheme + "://"
}
if a.Host != "" {
res += a.Host
}
if a.Port != "" {
if strings.HasPrefix(a.Original[len(res):], ":"+a.Port) {
// insert port only if the original has its own explicit port
res += ":" + a.Port
}
}
if a.Path != "" {
res += a.Path
}
return res
}
// standardizeAddress parses an address string into a structured format with separate
// scheme, host, port, and path portions, as well as the original input string.
func standardizeAddress(str string) (Address, error) {
input := str
// Split input into components (prepend with // to assert host by default)
if !strings.Contains(str, "//") && !strings.HasPrefix(str, "/") {
str = "//" + str
}
u, err := url.Parse(str)
if err != nil {
return Address{}, err
}
// separate host and port
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
host, port, err = net.SplitHostPort(u.Host + ":")
if err != nil {
host = u.Host
}
}
// see if we can set port based off scheme
if port == "" {
if u.Scheme == "http" {
port = HTTPPort
} else if u.Scheme == "https" {
port = HTTPSPort
}
}
// repeated or conflicting scheme is confusing, so error
if u.Scheme != "" && (port == "http" || port == "https") {
return Address{}, fmt.Errorf("[%s] scheme specified twice in address", input)
}
// error if scheme and port combination violate convention
if (u.Scheme == "http" && port == HTTPSPort) || (u.Scheme == "https" && port == HTTPPort) {
return Address{}, fmt.Errorf("[%s] scheme and port violate convention", input)
}
// standardize http and https ports to their respective port numbers
if port == "http" {
u.Scheme = "http"
port = HTTPPort
} else if port == "https" {
u.Scheme = "https"
port = HTTPSPort
}
return Address{Original: input, Scheme: u.Scheme, Host: host, Port: port, Path: u.Path}, err
}
// RegisterDevDirective splices name into the list of directives
// immediately before another directive. This function is ONLY
// for plugin development purposes! NEVER use it for a plugin
// that you are not currently building. If before is empty,
// the directive will be appended to the end of the list.
//
// It is imperative that directives execute in the proper
// order, and hard-coding the list of directives guarantees
// a correct, absolute order every time. This function is
// convenient when developing a plugin, but it does not
// guarantee absolute ordering. Multiple plugins registering
// directives with this function will lead to non-
// deterministic builds and buggy software.
//
// Directive names must be lower-cased and unique. Any errors
// here are fatal, and even successful calls print a message
// to stdout as a reminder to use it only in development.
func RegisterDevDirective(name, before string) {
if name == "" {
fmt.Println("[FATAL] Cannot register empty directive name")
os.Exit(1)
}
if strings.ToLower(name) != name {
fmt.Printf("[FATAL] %s: directive name must be lowercase\n", name)
os.Exit(1)
}
for _, dir := range directives {
if dir == name {
fmt.Printf("[FATAL] %s: directive name already exists\n", name)
os.Exit(1)
}
}
if before == "" {
directives = append(directives, name)
} else {
var found bool
for i, dir := range directives {
if dir == before {
directives = append(directives[:i], append([]string{name}, directives[i:]...)...)
found = true
break
}
}
if !found {
fmt.Printf("[FATAL] %s: directive not found\n", before)
os.Exit(1)
}
}
msg := fmt.Sprintf("Registered directive '%s' ", name)
if before == "" {
msg += "at end of list"
} else {
msg += fmt.Sprintf("before '%s'", before)
}
fmt.Printf("[DEV NOTICE] %s\n", msg)
}
// directives is the list of all directives known to exist for the
// http server type, including non-standard (3rd-party) directives.
// The ordering of this list is important.
var directives = []string{
// primitive actions that set up the fundamental vitals of each config
"root",
"index",
"bind",
"limits",
"timeouts",
"tls",
// services/utilities, or other directives that don't necessarily inject handlers
"startup", // TODO: Deprecate this directive
"shutdown", // TODO: Deprecate this directive
"on",
"supervisor", // github.com/lucaslorentz/caddy-supervisor
"request_id",
"realip", // github.com/captncraig/caddy-realip
"git", // github.com/abiosoft/caddy-git
// directives that add listener middleware to the stack
"proxyprotocol", // github.com/mastercactapus/caddy-proxyprotocol
// directives that add middleware to the stack
"locale", // github.com/simia-tech/caddy-locale
"log",
"cache", // github.com/nicolasazrak/caddy-cache
"rewrite",
"ext",
"gzip",
"header",
"geoip", // github.com/kodnaplakal/caddy-geoip
"errors",
"authz", // github.com/casbin/caddy-authz
"filter", // github.com/echocat/caddy-filter
"minify", // github.com/hacdias/caddy-minify
"ipfilter", // github.com/pyed/ipfilter
"ratelimit", // github.com/xuqingfeng/caddy-rate-limit
"expires", // github.com/epicagency/caddy-expires
"forwardproxy", // github.com/caddyserver/forwardproxy
"basicauth",
"redir",
"status",
"cors", // github.com/captncraig/cors/caddy
"nobots", // github.com/Xumeiquer/nobots
"mime",
"login", // github.com/tarent/loginsrv/caddy
"reauth", // github.com/freman/caddy-reauth
"jwt", // github.com/BTBurke/caddy-jwt
"jsonp", // github.com/pschlump/caddy-jsonp
"upload", // blitznote.com/src/caddy.upload
"multipass", // github.com/namsral/multipass/caddy
"internal",
"pprof",
"expvar",
"push",
"datadog", // github.com/payintech/caddy-datadog
"prometheus", // github.com/miekg/caddy-prometheus
"templates",
"proxy",
"fastcgi",
"cgi", // github.com/jung-kurt/caddy-cgi
"websocket",
"filemanager", // github.com/hacdias/filemanager/caddy/filemanager
"webdav", // github.com/hacdias/caddy-webdav
"markdown",
"browse",
"jekyll", // github.com/hacdias/filemanager/caddy/jekyll
"hugo", // github.com/hacdias/filemanager/caddy/hugo
"mailout", // github.com/SchumacherFM/mailout
"awses", // github.com/miquella/caddy-awses
"awslambda", // github.com/coopernurse/caddy-awslambda
"grpc", // github.com/pieterlouw/caddy-grpc
"gopkg", // github.com/zikes/gopkg
"restic", // github.com/restic/caddy
}
const (
// DefaultHost is the default host.
DefaultHost = ""
// DefaultPort is the default port.
DefaultPort = "2015"
// DefaultRoot is the default root folder.
DefaultRoot = "."
// DefaultHTTPPort is the default port for HTTP.
DefaultHTTPPort = "80"
// DefaultHTTPSPort is the default port for HTTPS.
DefaultHTTPSPort = "443"
)
// These "soft defaults" are configurable by
// command line flags, etc.
var (
// Root is the site root
Root = DefaultRoot
// Host is the site host
Host = DefaultHost
// Port is the site port
Port = DefaultPort
// GracefulTimeout is the maximum duration of a graceful shutdown.
GracefulTimeout time.Duration
// HTTP2 indicates whether HTTP2 is enabled or not.
HTTP2 bool
// QUIC indicates whether QUIC is enabled or not.
QUIC bool
// HTTPPort is the port to use for HTTP.
HTTPPort = DefaultHTTPPort
// HTTPSPort is the port to use for HTTPS.
HTTPSPort = DefaultHTTPSPort
)

View File

@@ -0,0 +1,349 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"strings"
"testing"
"sort"
"fmt"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile"
)
func TestStandardizeAddress(t *testing.T) {
for i, test := range []struct {
input string
scheme, host, port, path string
shouldErr bool
}{
{`localhost`, "", "localhost", "", "", false},
{`localhost:1234`, "", "localhost", "1234", "", false},
{`localhost:`, "", "localhost", "", "", false},
{`0.0.0.0`, "", "0.0.0.0", "", "", false},
{`127.0.0.1:1234`, "", "127.0.0.1", "1234", "", false},
{`:1234`, "", "", "1234", "", false},
{`[::1]`, "", "::1", "", "", false},
{`[::1]:1234`, "", "::1", "1234", "", false},
{`:`, "", "", "", "", false},
{`localhost:http`, "http", "localhost", "80", "", false},
{`localhost:https`, "https", "localhost", "443", "", false},
{`:http`, "http", "", "80", "", false},
{`:https`, "https", "", "443", "", false},
{`http://localhost:https`, "", "", "", "", true}, // conflict
{`http://localhost:http`, "", "", "", "", true}, // repeated scheme
{`http://localhost:443`, "", "", "", "", true}, // not conventional
{`https://localhost:80`, "", "", "", "", true}, // not conventional
{`http://localhost`, "http", "localhost", "80", "", false},
{`https://localhost`, "https", "localhost", "443", "", false},
{`http://127.0.0.1`, "http", "127.0.0.1", "80", "", false},
{`https://127.0.0.1`, "https", "127.0.0.1", "443", "", false},
{`http://[::1]`, "http", "::1", "80", "", false},
{`http://localhost:1234`, "http", "localhost", "1234", "", false},
{`https://127.0.0.1:1234`, "https", "127.0.0.1", "1234", "", false},
{`http://[::1]:1234`, "http", "::1", "1234", "", false},
{``, "", "", "", "", false},
{`::1`, "", "::1", "", "", true},
{`localhost::`, "", "localhost::", "", "", true},
{`#$%@`, "", "", "", "", true},
{`host/path`, "", "host", "", "/path", false},
{`http://host/`, "http", "host", "80", "/", false},
{`//asdf`, "", "asdf", "", "", false},
{`:1234/asdf`, "", "", "1234", "/asdf", false},
{`http://host/path`, "http", "host", "80", "/path", false},
{`https://host:443/path/foo`, "https", "host", "443", "/path/foo", false},
{`host:80/path`, "", "host", "80", "/path", false},
{`host:https/path`, "https", "host", "443", "/path", false},
{`/path`, "", "", "", "/path", false},
} {
actual, err := standardizeAddress(test.input)
if err != nil && !test.shouldErr {
t.Errorf("Test %d (%s): Expected no error, but had error: %v", i, test.input, err)
}
if err == nil && test.shouldErr {
t.Errorf("Test %d (%s): Expected error, but had none", i, test.input)
}
if !test.shouldErr && actual.Original != test.input {
t.Errorf("Test %d (%s): Expected original '%s', got '%s'", i, test.input, test.input, actual.Original)
}
if actual.Scheme != test.scheme {
t.Errorf("Test %d (%s): Expected scheme '%s', got '%s'", i, test.input, test.scheme, actual.Scheme)
}
if actual.Host != test.host {
t.Errorf("Test %d (%s): Expected host '%s', got '%s'", i, test.input, test.host, actual.Host)
}
if actual.Port != test.port {
t.Errorf("Test %d (%s): Expected port '%s', got '%s'", i, test.input, test.port, actual.Port)
}
if actual.Path != test.path {
t.Errorf("Test %d (%s): Expected path '%s', got '%s'", i, test.input, test.path, actual.Path)
}
}
}
func TestAddressVHost(t *testing.T) {
for i, test := range []struct {
addr Address
expected string
}{
{Address{Original: "host:1234"}, "host:1234"},
{Address{Original: "host:1234/foo"}, "host:1234/foo"},
{Address{Original: "host/foo"}, "host/foo"},
{Address{Original: "http://host/foo"}, "host/foo"},
{Address{Original: "https://host/foo"}, "host/foo"},
} {
actual := test.addr.VHost()
if actual != test.expected {
t.Errorf("Test %d: expected '%s' but got '%s'", i, test.expected, actual)
}
}
}
func TestAddressString(t *testing.T) {
for i, test := range []struct {
addr Address
expected string
}{
{Address{Scheme: "http", Host: "host", Port: "1234", Path: "/path"}, "http://host:1234/path"},
{Address{Scheme: "", Host: "host", Port: "", Path: ""}, "http://host"},
{Address{Scheme: "", Host: "host", Port: "80", Path: ""}, "http://host"},
{Address{Scheme: "", Host: "host", Port: "443", Path: ""}, "https://host"},
{Address{Scheme: "https", Host: "host", Port: "443", Path: ""}, "https://host"},
{Address{Scheme: "https", Host: "host", Port: "", Path: ""}, "https://host"},
{Address{Scheme: "", Host: "host", Port: "80", Path: "/path"}, "http://host/path"},
{Address{Scheme: "http", Host: "", Port: "1234", Path: ""}, "http://:1234"},
{Address{Scheme: "", Host: "", Port: "", Path: ""}, ""},
} {
actual := test.addr.String()
if actual != test.expected {
t.Errorf("Test %d: expected '%s' but got '%s'", i, test.expected, actual)
}
}
}
func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
Port = "9999"
filename := "Testfile"
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader(`localhost`)
sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil {
t.Fatalf("Expected no error setting up test, got: %v", err)
}
_, err = ctx.InspectServerBlocks(filename, sblocks)
if err != nil {
t.Fatalf("Didn't expect an error, but got: %v", err)
}
localhostKey := "localhost"
item, ok := ctx.keysToSiteConfigs[localhostKey]
if !ok {
availableKeys := make(sort.StringSlice, len(ctx.keysToSiteConfigs))
i := 0
for key := range ctx.keysToSiteConfigs {
availableKeys[i] = fmt.Sprintf("'%s'", key)
i++
}
availableKeys.Sort()
t.Errorf("`%s` not found within registered keys, only these are available: %s", localhostKey, strings.Join(availableKeys, ", "))
return
}
addr := item.Addr
if addr.Port != Port {
t.Errorf("Expected the port on the address to be set, but got: %#v", addr)
}
}
// See discussion on PR #2015
func TestInspectServerBlocksWithAdjustedAddress(t *testing.T) {
Port = DefaultPort
Host = "example.com"
filename := "Testfile"
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader("example.com {\n}\n:2015 {\n}")
sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil {
t.Fatalf("Expected no error setting up test, got: %v", err)
}
_, err = ctx.InspectServerBlocks(filename, sblocks)
if err == nil {
t.Fatalf("Expected an error because site definitions should overlap, got: %v", err)
}
}
func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
filename := "Testfile"
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
input := strings.NewReader("localhost {\n}\nLOCALHOST {\n}")
sblocks, err := caddyfile.Parse(filename, input, nil)
if err != nil {
t.Fatalf("Expected no error setting up test, got: %v", err)
}
_, err = ctx.InspectServerBlocks(filename, sblocks)
if err == nil {
t.Error("Expected an error because keys on this server type are case-insensitive (so these are duplicated), but didn't get an error")
}
}
func TestKeyNormalization(t *testing.T) {
originalCaseSensitivePath := CaseSensitivePath
defer func() {
CaseSensitivePath = originalCaseSensitivePath
}()
CaseSensitivePath = true
caseSensitiveData := []struct {
orig string
res string
}{
{
orig: "HTTP://A/ABCDEF",
res: "http://a/ABCDEF",
},
{
orig: "A/ABCDEF",
res: "a/ABCDEF",
},
{
orig: "A:2015/Port",
res: "a:2015/Port",
},
}
for _, item := range caseSensitiveData {
v := normalizedKey(item.orig)
if v != item.res {
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to true must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
}
}
CaseSensitivePath = false
caseInsensitiveData := []struct {
orig string
res string
}{
{
orig: "HTTP://A/ABCDEF",
res: "http://a/abcdef",
},
{
orig: "A/ABCDEF",
res: "a/abcdef",
},
{
orig: "A:2015/Port",
res: "a:2015/port",
},
}
for _, item := range caseInsensitiveData {
v := normalizedKey(item.orig)
if v != item.res {
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to false must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
}
}
}
func TestGetConfig(t *testing.T) {
// case insensitivity for key
con := caddy.NewTestController("http", "")
con.Key = "foo"
cfg := GetConfig(con)
con.Key = "FOO"
cfg2 := GetConfig(con)
if cfg != cfg2 {
t.Errorf("Expected same config using same key with different case; got %p and %p", cfg, cfg2)
}
// make sure different key returns different config
con.Key = "foobar"
cfg3 := GetConfig(con)
if cfg == cfg3 {
t.Errorf("Expected different configs using when key is different; got %p and %p", cfg, cfg3)
}
con.Key = "foo/foobar"
cfg4 := GetConfig(con)
con.Key = "foo/Foobar"
cfg5 := GetConfig(con)
if cfg4 == cfg5 {
t.Errorf("Expected different cases in path to differentiate keys in general")
}
}
func TestDirectivesList(t *testing.T) {
for i, dir1 := range directives {
if dir1 == "" {
t.Errorf("directives[%d]: empty directive name", i)
continue
}
if got, want := dir1, strings.ToLower(dir1); got != want {
t.Errorf("directives[%d]: %s should be lower-cased", i, dir1)
continue
}
for j := i + 1; j < len(directives); j++ {
dir2 := directives[j]
if dir1 == dir2 {
t.Errorf("directives[%d] (%s) is a duplicate of directives[%d] (%s)",
j, dir2, i, dir1)
}
}
}
}
func TestContextSaveConfig(t *testing.T) {
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
ctx.saveConfig("foo", new(SiteConfig))
if _, ok := ctx.keysToSiteConfigs["foo"]; !ok {
t.Error("Expected config to be saved, but it wasn't")
}
if got, want := len(ctx.siteConfigs), 1; got != want {
t.Errorf("Expected len(siteConfigs) == %d, but was %d", want, got)
}
ctx.saveConfig("Foobar", new(SiteConfig))
if _, ok := ctx.keysToSiteConfigs["foobar"]; ok {
t.Error("Did not expect to get config with case-insensitive key, but did")
}
if got, want := len(ctx.siteConfigs), 2; got != want {
t.Errorf("Expected len(siteConfigs) == %d, but was %d", want, got)
}
}
// Test to make sure we are correctly hiding the Caddyfile
func TestHideCaddyfile(t *testing.T) {
ctx := newContext(&caddy.Instance{Storage: make(map[interface{}]interface{})}).(*httpContext)
ctx.saveConfig("test", &SiteConfig{
Root: Root,
originCaddyfile: "Testfile",
})
err := hideCaddyfile(ctx)
if err != nil {
t.Fatalf("Failed to hide Caddyfile, got: %v", err)
return
}
if len(ctx.siteConfigs[0].HiddenFiles) == 0 {
t.Fatal("Failed to add Caddyfile to HiddenFiles.")
return
}
for _, file := range ctx.siteConfigs[0].HiddenFiles {
if file == "/Testfile" {
return
}
}
t.Fatal("Caddyfile missing from HiddenFiles")
}

View File

@@ -0,0 +1,261 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"bytes"
"io"
"net/http"
"sync"
"time"
)
// ResponseRecorder is a type of http.ResponseWriter that captures
// the status code written to it and also the size of the body
// written in the response. A status code does not have
// to be written, however, in which case 200 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
//
// Setting the Replacer field allows middlewares to type-assert
// the http.ResponseWriter to ResponseRecorder and set their own
// placeholder values for logging utilities to use.
//
// Beware when accessing the Replacer value; it may be nil!
type ResponseRecorder struct {
*ResponseWriterWrapper
Replacer Replacer
status int
size int
start time.Time
}
// NewResponseRecorder makes and returns a new ResponseRecorder.
// Because a status is not set unless WriteHeader is called
// explicitly, this constructor initializes with a status code
// of 200 to cover the default case.
func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder {
return &ResponseRecorder{
ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: w},
status: http.StatusOK,
start: time.Now(),
}
}
// WriteHeader records the status code and calls the
// underlying ResponseWriter's WriteHeader method.
func (r *ResponseRecorder) WriteHeader(status int) {
r.status = status
r.ResponseWriterWrapper.WriteHeader(status)
}
// Write is a wrapper that records the size of the body
// that gets written.
func (r *ResponseRecorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriterWrapper.Write(buf)
if err == nil {
r.size += n
}
return n, err
}
// Size returns the size of the recorded response body.
func (r *ResponseRecorder) Size() int {
return r.size
}
// Status returns the recorded response status code.
func (r *ResponseRecorder) Status() int {
return r.status
}
// ResponseBuffer is a type that conditionally buffers the
// response in memory. It implements http.ResponseWriter so
// that it can stream the response if it is not buffering.
// Whether it buffers is decided by a func passed into the
// constructor, NewResponseBuffer.
//
// This type implements http.ResponseWriter, so you can pass
// this to the Next() middleware in the chain and record its
// response. However, since the entire response body will be
// buffered in memory, only use this when explicitly configured
// and required for some specific reason. For example, the
// text/template package only parses templates out of []byte
// and not io.Reader, so the templates directive uses this
// type to obtain the entire template text, but only on certain
// requests that match the right Content-Type, etc.
//
// ResponseBuffer also implements io.ReaderFrom for performance
// reasons. The standard lib's http.response type (unexported)
// uses io.Copy to write the body. io.Copy makes an allocation
// if the destination does not have a ReadFrom method (or if
// the source does not have a WriteTo method, but that's
// irrelevant here). Our ReadFrom is smart: if buffering, it
// calls the buffer's ReadFrom, which makes no allocs because
// it is already a buffer! If we're streaming the response
// instead, ReadFrom uses io.CopyBuffer with a pooled buffer
// that is managed within this package.
type ResponseBuffer struct {
*ResponseWriterWrapper
Buffer *bytes.Buffer
header http.Header
status int
shouldBuffer func(status int, header http.Header) bool
stream bool
rw http.ResponseWriter
wroteHeader bool
}
// NewResponseBuffer returns a new ResponseBuffer that will
// use buf to store the full body of the response if shouldBuffer
// returns true. If shouldBuffer returns false, then the response
// body will be streamed directly to rw.
//
// shouldBuffer will be passed the status code and header fields of
// the response. With that information, the function should decide
// whether to buffer the response in memory. For example: the templates
// directive uses this to determine whether the response is the
// right Content-Type (according to user config) for a template.
//
// For performance, the buf you pass in should probably be obtained
// from a sync.Pool in order to reuse allocated space.
func NewResponseBuffer(buf *bytes.Buffer, rw http.ResponseWriter,
shouldBuffer func(status int, header http.Header) bool) *ResponseBuffer {
rb := &ResponseBuffer{
Buffer: buf,
header: make(http.Header),
status: http.StatusOK, // default status code
shouldBuffer: shouldBuffer,
rw: rw,
}
rb.ResponseWriterWrapper = &ResponseWriterWrapper{ResponseWriter: rw}
return rb
}
// Header returns the response header map.
func (rb *ResponseBuffer) Header() http.Header {
return rb.header
}
// WriteHeader calls shouldBuffer to decide whether the
// upcoming body should be buffered, and then writes
// the header to the response.
func (rb *ResponseBuffer) WriteHeader(status int) {
if rb.wroteHeader {
return
}
rb.wroteHeader = true
rb.status = status
rb.stream = !rb.shouldBuffer(status, rb.header)
if rb.stream {
rb.CopyHeader()
rb.ResponseWriterWrapper.WriteHeader(status)
}
}
// Write writes buf to rb.Buffer if buffering, otherwise
// to the ResponseWriter directly if streaming.
func (rb *ResponseBuffer) Write(buf []byte) (int, error) {
if !rb.wroteHeader {
rb.WriteHeader(http.StatusOK)
}
if rb.stream {
return rb.ResponseWriterWrapper.Write(buf)
}
return rb.Buffer.Write(buf)
}
// Buffered returns whether rb has decided to buffer the response.
func (rb *ResponseBuffer) Buffered() bool {
return !rb.stream
}
// CopyHeader copies the buffered header in rb to the ResponseWriter,
// but it does not write the header out.
func (rb *ResponseBuffer) CopyHeader() {
for field, val := range rb.header {
rb.ResponseWriterWrapper.Header()[field] = val
}
}
// ReadFrom avoids allocations when writing to the buffer (if buffering),
// and reduces allocations when writing to the ResponseWriter directly
// (if streaming).
//
// In local testing with the templates directive, req/sec were improved
// from ~8,200 to ~9,600 on templated files by ensuring that this type
// implements io.ReaderFrom.
func (rb *ResponseBuffer) ReadFrom(src io.Reader) (int64, error) {
if !rb.wroteHeader {
rb.WriteHeader(http.StatusOK)
}
if rb.stream {
// first see if we can avoid any allocations at all
if wt, ok := src.(io.WriterTo); ok {
return wt.WriteTo(rb.ResponseWriterWrapper)
}
// if not, use a pooled copy buffer to reduce allocs
// (this improved req/sec from ~25,300 to ~27,000 on
// static files served directly with the fileserver,
// but results fluctuated a little on each run).
// a note of caution:
// https://go-review.googlesource.com/c/22134#message-ff351762308fe05f6b72a487d6842e3988916486
buf := respBufPool.Get().([]byte)
n, err := io.CopyBuffer(rb.ResponseWriterWrapper, src, buf)
respBufPool.Put(buf) // defer'ing this slowed down benchmarks a smidgin, I think
return n, err
}
return rb.Buffer.ReadFrom(src)
}
// StatusCodeWriter returns an http.ResponseWriter that always
// writes the status code stored in rb from when a response
// was buffered to it.
func (rb *ResponseBuffer) StatusCodeWriter(w http.ResponseWriter) http.ResponseWriter {
return forcedStatusCodeWriter{w, rb}
}
// forcedStatusCodeWriter is used to force a status code when
// writing the header. It uses the status code saved on rb.
// This is useful if passing a http.ResponseWriter into
// http.ServeContent because ServeContent hard-codes 2xx status
// codes. If we buffered the response, we force that status code
// instead.
type forcedStatusCodeWriter struct {
http.ResponseWriter
rb *ResponseBuffer
}
func (fscw forcedStatusCodeWriter) WriteHeader(int) {
fscw.ResponseWriter.WriteHeader(fscw.rb.status)
}
// respBufPool is used for io.CopyBuffer when ResponseBuffer
// is configured to stream a response.
var respBufPool = &sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
// Interface guards
var (
_ HTTPInterfaces = (*ResponseRecorder)(nil)
_ HTTPInterfaces = (*ResponseBuffer)(nil)
_ io.ReaderFrom = (*ResponseBuffer)(nil)
)

View File

@@ -0,0 +1,54 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestNewResponseRecorder(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
if !(recordRequest.ResponseWriter == w) {
t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
}
if recordRequest.status != http.StatusOK {
t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
}
}
func TestWriteHeader(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
recordRequest.WriteHeader(401)
if w.Code != 401 || recordRequest.status != 401 {
t.Fatalf("Expected Response status to be set to 401, but found %d\n", recordRequest.status)
}
}
func TestWrite(t *testing.T) {
w := httptest.NewRecorder()
responseTestString := "test"
recordRequest := NewResponseRecorder(w)
buf := []byte(responseTestString)
recordRequest.Write(buf)
if recordRequest.size != len(buf) {
t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
}
if w.Body.String() != responseTestString {
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
}
}

View File

@@ -0,0 +1,471 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"bytes"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path"
"strconv"
"strings"
"time"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddytls"
)
// requestReplacer is a strings.Replacer which is used to
// encode literal \r and \n characters and keep everything
// on one line
var requestReplacer = strings.NewReplacer(
"\r", "\\r",
"\n", "\\n",
)
var now = time.Now
// Replacer is a type which can replace placeholder
// substrings in a string with actual values from a
// http.Request and ResponseRecorder. Always use
// NewReplacer to get one of these. Any placeholders
// made with Set() should overwrite existing values if
// the key is already used.
type Replacer interface {
Replace(string) string
Set(key, value string)
}
// replacer implements Replacer. customReplacements
// is used to store custom replacements created with
// Set() until the time of replacement, at which point
// they will be used to overwrite other replacements
// if there is a name conflict.
type replacer struct {
customReplacements map[string]string
emptyValue string
responseRecorder *ResponseRecorder
request *http.Request
requestBody *limitWriter
}
type limitWriter struct {
w bytes.Buffer
remain int
}
func newLimitWriter(max int) *limitWriter {
return &limitWriter{
w: bytes.Buffer{},
remain: max,
}
}
func (lw *limitWriter) Write(p []byte) (int, error) {
// skip if we are full
if lw.remain <= 0 {
return len(p), nil
}
if n := len(p); n > lw.remain {
p = p[:lw.remain]
}
n, err := lw.w.Write(p)
lw.remain -= n
return n, err
}
func (lw *limitWriter) String() string {
return lw.w.String()
}
// NewReplacer makes a new replacer based on r and rr which
// are used for request and response placeholders, respectively.
// Request placeholders are created immediately, whereas
// response placeholders are not created until Replace()
// is invoked. rr may be nil if it is not available.
// emptyValue should be the string that is used in place
// of empty string (can still be empty string).
func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer {
repl := &replacer{
request: r,
responseRecorder: rr,
emptyValue: emptyValue,
}
// extract customReplacements from a request replacer when present.
if existing, ok := r.Context().Value(ReplacerCtxKey).(*replacer); ok {
repl.requestBody = existing.requestBody
repl.customReplacements = existing.customReplacements
} else {
// if there is no existing replacer, build one from scratch.
rb := newLimitWriter(MaxLogBodySize)
if r.Body != nil {
r.Body = struct {
io.Reader
io.Closer
}{io.TeeReader(r.Body, rb), io.Closer(r.Body)}
}
repl.requestBody = rb
repl.customReplacements = make(map[string]string)
}
return repl
}
func canLogRequest(r *http.Request) bool {
if r.Method == "POST" || r.Method == "PUT" {
for _, cType := range r.Header[headerContentType] {
// the cType could have charset and other info
if strings.Contains(cType, contentTypeJSON) || strings.Contains(cType, contentTypeXML) {
return true
}
}
}
return false
}
// unescapeBraces finds escaped braces in s and returns
// a string with those braces unescaped.
func unescapeBraces(s string) string {
s = strings.Replace(s, "\\{", "{", -1)
s = strings.Replace(s, "\\}", "}", -1)
return s
}
// Replace performs a replacement of values on s and returns
// the string with the replaced values.
func (r *replacer) Replace(s string) string {
// Do not attempt replacements if no placeholder is found.
if !strings.ContainsAny(s, "{}") {
return s
}
result := ""
Placeholders: // process each placeholder in sequence
for {
var idxStart, idxEnd int
idxOffset := 0
for { // find first unescaped opening brace
searchSpace := s[idxOffset:]
idxStart = strings.Index(searchSpace, "{")
if idxStart == -1 {
// no more placeholders
break Placeholders
}
if idxStart == 0 || searchSpace[idxStart-1] != '\\' {
// preceding character is not an escape
idxStart += idxOffset
break
}
// the brace we found was escaped
// search the rest of the string next
idxOffset += idxStart + 1
}
idxOffset = 0
for { // find first unescaped closing brace
searchSpace := s[idxStart+idxOffset:]
idxEnd = strings.Index(searchSpace, "}")
if idxEnd == -1 {
// unpaired placeholder
break Placeholders
}
if idxEnd == 0 || searchSpace[idxEnd-1] != '\\' {
// preceding character is not an escape
idxEnd += idxOffset + idxStart
break
}
// the brace we found was escaped
// search the rest of the string next
idxOffset += idxEnd + 1
}
// get a replacement for the unescaped placeholder
placeholder := unescapeBraces(s[idxStart : idxEnd+1])
replacement := r.getSubstitution(placeholder)
// append unescaped prefix + replacement
result += strings.TrimPrefix(unescapeBraces(s[:idxStart]), "\\") + replacement
// strip out scanned parts
s = s[idxEnd+1:]
}
// append unscanned parts
return result + unescapeBraces(s)
}
func roundDuration(d time.Duration) time.Duration {
if d >= time.Millisecond {
return round(d, time.Millisecond)
} else if d >= time.Microsecond {
return round(d, time.Microsecond)
}
return d
}
// round rounds d to the nearest r
func round(d, r time.Duration) time.Duration {
if r <= 0 {
return d
}
neg := d < 0
if neg {
d = -d
}
if m := d % r; m+m < r {
d = d - m
} else {
d = d + r - m
}
if neg {
return -d
}
return d
}
// getSubstitution retrieves value from corresponding key
func (r *replacer) getSubstitution(key string) string {
// search custom replacements first
if value, ok := r.customReplacements[key]; ok {
return value
}
// search request headers then
if key[1] == '>' {
want := key[2 : len(key)-1]
for key, values := range r.request.Header {
// Header placeholders (case-insensitive)
if strings.EqualFold(key, want) {
return strings.Join(values, ",")
}
}
}
// search response headers then
if r.responseRecorder != nil && key[1] == '<' {
want := key[2 : len(key)-1]
for key, values := range r.responseRecorder.Header() {
// Header placeholders (case-insensitive)
if strings.EqualFold(key, want) {
return strings.Join(values, ",")
}
}
}
// next check for cookies
if key[1] == '~' {
name := key[2 : len(key)-1]
if cookie, err := r.request.Cookie(name); err == nil {
return cookie.Value
}
}
// next check for query argument
if key[1] == '?' {
query := r.request.URL.Query()
name := key[2 : len(key)-1]
return query.Get(name)
}
// search default replacements in the end
switch key {
case "{method}":
return r.request.Method
case "{scheme}":
if r.request.TLS != nil {
return "https"
}
return "http"
case "{hostname}":
name, err := os.Hostname()
if err != nil {
return r.emptyValue
}
return name
case "{host}":
return r.request.Host
case "{hostonly}":
host, _, err := net.SplitHostPort(r.request.Host)
if err != nil {
return r.request.Host
}
return host
case "{path}":
u, _ := r.request.Context().Value(OriginalURLCtxKey).(url.URL)
return u.Path
case "{path_escaped}":
u, _ := r.request.Context().Value(OriginalURLCtxKey).(url.URL)
return url.QueryEscape(u.Path)
case "{request_id}":
reqid, _ := r.request.Context().Value(RequestIDCtxKey).(string)
return reqid
case "{rewrite_path}":
return r.request.URL.Path
case "{rewrite_path_escaped}":
return url.QueryEscape(r.request.URL.Path)
case "{query}":
u, _ := r.request.Context().Value(OriginalURLCtxKey).(url.URL)
return u.RawQuery
case "{query_escaped}":
u, _ := r.request.Context().Value(OriginalURLCtxKey).(url.URL)
return url.QueryEscape(u.RawQuery)
case "{fragment}":
u, _ := r.request.Context().Value(OriginalURLCtxKey).(url.URL)
return u.Fragment
case "{proto}":
return r.request.Proto
case "{remote}":
host, _, err := net.SplitHostPort(r.request.RemoteAddr)
if err != nil {
return r.request.RemoteAddr
}
return host
case "{port}":
_, port, err := net.SplitHostPort(r.request.RemoteAddr)
if err != nil {
return r.emptyValue
}
return port
case "{uri}":
u, _ := r.request.Context().Value(OriginalURLCtxKey).(url.URL)
return u.RequestURI()
case "{uri_escaped}":
u, _ := r.request.Context().Value(OriginalURLCtxKey).(url.URL)
return url.QueryEscape(u.RequestURI())
case "{rewrite_uri}":
return r.request.URL.RequestURI()
case "{rewrite_uri_escaped}":
return url.QueryEscape(r.request.URL.RequestURI())
case "{when}":
return now().Format(timeFormat)
case "{when_iso}":
return now().UTC().Format(timeFormatISOUTC)
case "{when_unix}":
return strconv.FormatInt(now().Unix(), 10)
case "{file}":
_, file := path.Split(r.request.URL.Path)
return file
case "{dir}":
dir, _ := path.Split(r.request.URL.Path)
return dir
case "{request}":
dump, err := httputil.DumpRequest(r.request, false)
if err != nil {
return r.emptyValue
}
return requestReplacer.Replace(string(dump))
case "{request_body}":
if !canLogRequest(r.request) {
return r.emptyValue
}
_, err := ioutil.ReadAll(r.request.Body)
if err != nil {
if err == ErrMaxBytesExceeded {
return r.emptyValue
}
}
return requestReplacer.Replace(r.requestBody.String())
case "{mitm}":
if val, ok := r.request.Context().Value(caddy.CtxKey("mitm")).(bool); ok {
if val {
return "likely"
}
return "unlikely"
}
return "unknown"
case "{status}":
if r.responseRecorder == nil {
return r.emptyValue
}
return strconv.Itoa(r.responseRecorder.status)
case "{size}":
if r.responseRecorder == nil {
return r.emptyValue
}
return strconv.Itoa(r.responseRecorder.size)
case "{latency}":
if r.responseRecorder == nil {
return r.emptyValue
}
return roundDuration(time.Since(r.responseRecorder.start)).String()
case "{latency_ms}":
if r.responseRecorder == nil {
return r.emptyValue
}
elapsedDuration := time.Since(r.responseRecorder.start)
return strconv.FormatInt(convertToMilliseconds(elapsedDuration), 10)
case "{tls_protocol}":
if r.request.TLS != nil {
for k, v := range caddytls.SupportedProtocols {
if v == r.request.TLS.Version {
return k
}
}
return "tls" // this should never happen, but guard in case
}
return r.emptyValue // because not using a secure channel
case "{tls_cipher}":
if r.request.TLS != nil {
for k, v := range caddytls.SupportedCiphersMap {
if v == r.request.TLS.CipherSuite {
return k
}
}
return "UNKNOWN" // this should never happen, but guard in case
}
return r.emptyValue
default:
// {labelN}
if strings.HasPrefix(key, "{label") {
nStr := key[6 : len(key)-1] // get the integer N in "{labelN}"
n, err := strconv.Atoi(nStr)
if err != nil || n < 1 {
return r.emptyValue
}
labels := strings.Split(r.request.Host, ".")
if n > len(labels) {
return r.emptyValue
}
return labels[n-1]
}
}
return r.emptyValue
}
// convertToMilliseconds returns the number of milliseconds in the given duration
func convertToMilliseconds(d time.Duration) int64 {
return d.Nanoseconds() / 1e6
}
// Set sets key to value in the r.customReplacements map.
func (r *replacer) Set(key, value string) {
r.customReplacements["{"+key+"}"] = value
}
const (
timeFormat = "02/Jan/2006:15:04:05 -0700"
timeFormatISOUTC = "2006-01-02T15:04:05Z" // ISO 8601 with timezone to be assumed as UTC
headerContentType = "Content-Type"
contentTypeJSON = "application/json"
contentTypeXML = "application/xml"
// MaxLogBodySize limits the size of logged request's body
MaxLogBodySize = 100 * 1024
)

View File

@@ -0,0 +1,352 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
)
func TestNewReplacer(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost", reader)
if err != nil {
t.Fatal("Request Formation Failed\n")
}
rep := NewReplacer(request, recordRequest, "")
switch v := rep.(type) {
case *replacer:
if v.getSubstitution("{host}") != "localhost" {
t.Error("Expected host to be localhost")
}
if v.getSubstitution("{method}") != "POST" {
t.Error("Expected request method to be POST")
}
default:
t.Fatalf("Expected *replacer underlying Replacer type, got: %#v", rep)
}
}
func TestReplace(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost.local/?foo=bar", reader)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
request = request.WithContext(ctx)
request.Header.Set("Custom", "foobarbaz")
request.Header.Set("ShorterVal", "1")
repl := NewReplacer(request, recordRequest, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
hostname, err := os.Hostname()
if err != nil {
t.Fatalf("Failed to determine hostname: %v", err)
}
old := now
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
defer func() {
now = old
}()
testCases := []struct {
template string
expect string
}{
{"This hostname is {hostname}", "This hostname is " + hostname},
{"This host is {host}.", "This host is localhost.local."},
{"This request method is {method}.", "This request method is POST."},
{"The response status is {status}.", "The response status is 200."},
{"{when}", "02/Jan/2006:15:04:05 +0000"},
{"{when_iso}", "2006-01-02T15:04:12Z"},
{"{when_unix}", "1136214252"},
{"The Custom header is {>Custom}.", "The Custom header is foobarbaz."},
{"The CustomAdd header is {>CustomAdd}.", "The CustomAdd header is caddy."},
{"The Custom response header is {<Custom}.", "The Custom response header is CustomResponseHeader."},
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost.local\\r\\n" +
"Cookie: foo=bar; taste=delicious\\r\\nCustom: foobarbaz\\r\\nCustomadd: caddy\\r\\n" +
"Shorterval: 1\\r\\n\\r\\n."},
{"The cUsToM header is {>cUsToM}...", "The cUsToM header is foobarbaz..."},
{"The cUsToM response header is {<CuSTom}.", "The cUsToM response header is CustomResponseHeader."},
{"The Non-Existent header is {>Non-Existent}.", "The Non-Existent header is -."},
{"Bad {host placeholder...", "Bad {host placeholder..."},
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
{"Bad {>Custom placeholder {>ShorterVal}", "Bad -"},
{"Bad {}", "Bad -"},
{"Cookies are {~taste}", "Cookies are delicious"},
{"Missing cookie is {~missing}", "Missing cookie is -"},
{"Query string is {query}", "Query string is foo=bar"},
{"Query string value for foo is {?foo}", "Query string value for foo is bar"},
{"Missing query string argument is {?missing}", "Missing query string argument is "},
{"{label1} {label2} {label3} {label4}", "localhost local - -"},
{"Label with missing number is {label} or {labelQQ}", "Label with missing number is - or -"},
{"\\{ 'hostname': '{hostname}' \\}", "{ 'hostname': '" + hostname + "' }"},
}
for _, c := range testCases {
if expected, actual := c.expect, repl.Replace(c.template); expected != actual {
t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)
}
}
complexCases := []struct {
template string
replacements map[string]string
expect string
}{
{
"/a{1}/{2}",
map[string]string{
"{1}": "12",
"{2}": "",
},
"/a12/"},
}
for _, c := range complexCases {
repl := &replacer{
customReplacements: c.replacements,
}
if expected, actual := c.expect, repl.Replace(c.template); expected != actual {
t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)
}
}
}
func BenchmarkReplace(b *testing.B) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
b.Fatalf("Failed to make request: %v", err)
}
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
request = request.WithContext(ctx)
request.Header.Set("Custom", "foobarbaz")
request.Header.Set("ShorterVal", "1")
repl := NewReplacer(request, recordRequest, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
repl.Replace("This hostname is {hostname}")
}
}
func BenchmarkReplaceEscaped(b *testing.B) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
b.Fatalf("Failed to make request: %v", err)
}
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
request = request.WithContext(ctx)
request.Header.Set("Custom", "foobarbaz")
request.Header.Set("ShorterVal", "1")
repl := NewReplacer(request, recordRequest, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
// add some respons headers
recordRequest.Header().Set("Custom", "CustomResponseHeader")
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
repl.Replace("\\{ 'hostname': '{hostname}' \\}")
}
}
func TestResponseRecorderNil(t *testing.T) {
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
request.Header.Set("Custom", "foobarbaz")
repl := NewReplacer(request, nil, "-")
// add some headers after creating replacer
request.Header.Set("CustomAdd", "caddy")
request.Header.Set("Cookie", "foo=bar; taste=delicious")
old := now
now = func() time.Time {
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
}
defer func() {
now = old
}()
testCases := []struct {
template string
expect string
}{
{"The Custom response header is {<Custom}.", "The Custom response header is -."},
}
for _, c := range testCases {
if expected, actual := c.expect, repl.Replace(c.template); expected != actual {
t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)
}
}
}
func TestSet(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://localhost", reader)
if err != nil {
t.Fatalf("Request Formation Failed: %s\n", err.Error())
}
repl := NewReplacer(request, recordRequest, "")
repl.Set("host", "getcaddy.com")
repl.Set("method", "GET")
repl.Set("status", "201")
repl.Set("variable", "value")
if repl.Replace("This host is {host}") != "This host is getcaddy.com" {
t.Error("Expected host replacement failed")
}
if repl.Replace("This request method is {method}") != "This request method is GET" {
t.Error("Expected method replacement failed")
}
if repl.Replace("The response status is {status}") != "The response status is 201" {
t.Error("Expected status replacement failed")
}
if repl.Replace("The value of variable is {variable}") != "The value of variable is value" {
t.Error("Expected variable replacement failed")
}
}
// Test function to test that various placeholders hold correct values after a rewrite
// has been performed. The NewRequest actually contains the rewritten value.
func TestPathRewrite(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
reader := strings.NewReader(`{"username": "dennis"}`)
request, err := http.NewRequest("POST", "http://getcaddy.com/index.php?key=value", reader)
if err != nil {
t.Fatalf("Request Formation Failed: %s\n", err.Error())
}
urlCopy := *request.URL
urlCopy.Path = "a/custom/path.php"
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, urlCopy)
request = request.WithContext(ctx)
repl := NewReplacer(request, recordRequest, "")
if got, want := repl.Replace("This path is '{path}'"), "This path is 'a/custom/path.php'"; got != want {
t.Errorf("{path} replacement failed; got '%s', want '%s'", got, want)
}
if got, want := repl.Replace("This path is {rewrite_path}"), "This path is /index.php"; got != want {
t.Errorf("{rewrite_path} replacement failed; got '%s', want '%s'", got, want)
}
if got, want := repl.Replace("This path is '{uri}'"), "This path is 'a/custom/path.php?key=value'"; got != want {
t.Errorf("{uri} replacement failed; got '%s', want '%s'", got, want)
}
if got, want := repl.Replace("This path is {rewrite_uri}"), "This path is /index.php?key=value"; got != want {
t.Errorf("{rewrite_uri} replacement failed; got '%s', want '%s'", got, want)
}
}
func TestRound(t *testing.T) {
var tests = map[time.Duration]time.Duration{
// 599.935µs -> 560µs
559935 * time.Nanosecond: 560 * time.Microsecond,
// 1.55ms -> 2ms
1550 * time.Microsecond: 2 * time.Millisecond,
// 1.5555s -> 1.556s
1555500 * time.Microsecond: 1556 * time.Millisecond,
// 1m2.0035s -> 1m2.004s
62003500 * time.Microsecond: 62004 * time.Millisecond,
}
for dur, expected := range tests {
rounded := roundDuration(dur)
if rounded != expected {
t.Errorf("Expected %v, Got %v", expected, rounded)
}
}
}
func TestMillisecondConverstion(t *testing.T) {
var testCases = map[time.Duration]int64{
2 * time.Second: 2000,
9039492 * time.Nanosecond: 9,
1000 * time.Microsecond: 1,
127 * time.Nanosecond: 0,
0 * time.Millisecond: 0,
255 * time.Millisecond: 255,
}
for dur, expected := range testCases {
numMillisecond := convertToMilliseconds(dur)
if numMillisecond != expected {
t.Errorf("Expected %v. Got %v", expected, numMillisecond)
}
}
}

View File

@@ -0,0 +1,79 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"bufio"
"net"
"net/http"
)
// ResponseWriterWrapper wrappers underlying ResponseWriter
// and inherits its Hijacker/Pusher/CloseNotifier/Flusher as well.
type ResponseWriterWrapper struct {
http.ResponseWriter
}
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (rww *ResponseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := rww.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, NonHijackerError{Underlying: rww.ResponseWriter}
}
// Flush implements http.Flusher. It simply wraps the underlying
// ResponseWriter's Flush method if there is one, or panics.
func (rww *ResponseWriterWrapper) Flush() {
if f, ok := rww.ResponseWriter.(http.Flusher); ok {
f.Flush()
} else {
panic(NonFlusherError{Underlying: rww.ResponseWriter})
}
}
// CloseNotify implements http.CloseNotifier.
// It just inherits the underlying ResponseWriter's CloseNotify method.
// It panics if the underlying ResponseWriter is not a CloseNotifier.
func (rww *ResponseWriterWrapper) CloseNotify() <-chan bool {
if cn, ok := rww.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
panic(NonCloseNotifierError{Underlying: rww.ResponseWriter})
}
// Push implements http.Pusher.
// It just inherits the underlying ResponseWriter's Push method.
// It panics if the underlying ResponseWriter is not a Pusher.
func (rww *ResponseWriterWrapper) Push(target string, opts *http.PushOptions) error {
if pusher, hasPusher := rww.ResponseWriter.(http.Pusher); hasPusher {
return pusher.Push(target, opts)
}
return NonPusherError{Underlying: rww.ResponseWriter}
}
// HTTPInterfaces mix all the interfaces that middleware ResponseWriters need to support.
type HTTPInterfaces interface {
http.ResponseWriter
http.Pusher
http.Flusher
http.CloseNotifier
http.Hijacker
}
// Interface guards
var _ HTTPInterfaces = (*ResponseWriterWrapper)(nil)

View File

@@ -0,0 +1,138 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"errors"
"io"
"path/filepath"
"strconv"
"gopkg.in/natefinch/lumberjack.v2"
)
// LogRoller implements a type that provides a rolling logger.
type LogRoller struct {
Filename string
MaxSize int
MaxAge int
MaxBackups int
Compress bool
LocalTime bool
}
// GetLogWriter returns an io.Writer that writes to a rolling logger.
// This should be called only from the main goroutine (like during
// server setup) because this method is not thread-safe; it is careful
// to create only one log writer per log file, even if the log file
// is shared by different sites or middlewares. This ensures that
// rolling is synchronized, since a process (or multiple processes)
// should not create more than one roller on the same file at the
// same time. See issue #1363.
func (l LogRoller) GetLogWriter() io.Writer {
absPath, err := filepath.Abs(l.Filename)
if err != nil {
absPath = l.Filename // oh well, hopefully they're consistent in how they specify the filename
}
lj, has := lumberjacks[absPath]
if !has {
lj = &lumberjack.Logger{
Filename: l.Filename,
MaxSize: l.MaxSize,
MaxAge: l.MaxAge,
MaxBackups: l.MaxBackups,
Compress: l.Compress,
LocalTime: l.LocalTime,
}
lumberjacks[absPath] = lj
}
return lj
}
// IsLogRollerSubdirective is true if the subdirective is for the log roller.
func IsLogRollerSubdirective(subdir string) bool {
return subdir == directiveRotateSize ||
subdir == directiveRotateAge ||
subdir == directiveRotateKeep ||
subdir == directiveRotateCompress
}
var invalidRollerParameterErr = errors.New("invalid roller parameter")
// ParseRoller parses roller contents out of c.
func ParseRoller(l *LogRoller, what string, where ...string) error {
if l == nil {
l = DefaultLogRoller()
}
// rotate_compress doesn't accept any parameters.
// others only accept one parameter
if (what == directiveRotateCompress && len(where) != 0) ||
(what != directiveRotateCompress && len(where) != 1) {
return invalidRollerParameterErr
}
var (
value int
err error
)
if what != directiveRotateCompress {
value, err = strconv.Atoi(where[0])
if err != nil {
return err
}
}
switch what {
case directiveRotateSize:
l.MaxSize = value
case directiveRotateAge:
l.MaxAge = value
case directiveRotateKeep:
l.MaxBackups = value
case directiveRotateCompress:
l.Compress = true
}
return nil
}
// DefaultLogRoller will roll logs by default.
func DefaultLogRoller() *LogRoller {
return &LogRoller{
MaxSize: defaultRotateSize,
MaxAge: defaultRotateAge,
MaxBackups: defaultRotateKeep,
Compress: false,
LocalTime: true,
}
}
const (
// defaultRotateSize is 100 MB.
defaultRotateSize = 100
// defaultRotateAge is 14 days.
defaultRotateAge = 14
// defaultRotateKeep is 10 files.
defaultRotateKeep = 10
directiveRotateSize = "rotate_size"
directiveRotateAge = "rotate_age"
directiveRotateKeep = "rotate_keep"
directiveRotateCompress = "rotate_compress"
)
// lumberjacks maps log filenames to the logger
// that is being used to keep them rolled/maintained.
var lumberjacks = make(map[string]*lumberjack.Logger)

View File

@@ -0,0 +1,589 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package httpserver implements an HTTP server on top of Caddy.
package httpserver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/staticfiles"
"github.com/mholt/caddy/caddytls"
"github.com/mholt/caddy/telemetry"
)
// Server is the HTTP server implementation.
type Server struct {
Server *http.Server
quicServer *h2quic.Server
listener net.Listener
listenerMu sync.Mutex
sites []*SiteConfig
connTimeout time.Duration // max time to wait for a connection before force stop
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
vhosts *vhostTrie
}
// ensure it satisfies the interface
var _ caddy.GracefulServer = new(Server)
var defaultALPN = []string{"h2", "http/1.1"}
// makeTLSConfig extracts TLS settings from each site config to
// build a tls.Config usable in Caddy HTTP servers. The returned
// config will be nil if TLS is disabled for these sites.
func makeTLSConfig(group []*SiteConfig) (*tls.Config, error) {
var tlsConfigs []*caddytls.Config
for i := range group {
if HTTP2 && len(group[i].TLS.ALPN) == 0 {
// if no application-level protocol was configured up to now,
// default to HTTP/2, then HTTP/1.1 if necessary
group[i].TLS.ALPN = defaultALPN
}
tlsConfigs = append(tlsConfigs, group[i].TLS)
}
return caddytls.MakeTLSConfig(tlsConfigs)
}
func getFallbacks(sites []*SiteConfig) []string {
fallbacks := []string{}
for _, sc := range sites {
if sc.FallbackSite {
fallbacks = append(fallbacks, sc.Addr.Host)
}
}
return fallbacks
}
// NewServer creates a new Server instance that will listen on addr
// and will serve the sites configured in group.
func NewServer(addr string, group []*SiteConfig) (*Server, error) {
s := &Server{
Server: makeHTTPServerWithTimeouts(addr, group),
vhosts: newVHostTrie(),
sites: group,
connTimeout: GracefulTimeout,
}
s.vhosts.fallbackHosts = append(s.vhosts.fallbackHosts, getFallbacks(group)...)
s.Server = makeHTTPServerWithHeaderLimit(s.Server, group)
s.Server.Handler = s // this is weird, but whatever
// extract TLS settings from each site config to build
// a tls.Config, which will not be nil if TLS is enabled
tlsConfig, err := makeTLSConfig(group)
if err != nil {
return nil, err
}
s.Server.TLSConfig = tlsConfig
// if TLS is enabled, make sure we prepare the Server accordingly
if s.Server.TLSConfig != nil {
// enable QUIC if desired (requires HTTP/2)
if HTTP2 && QUIC {
s.quicServer = &h2quic.Server{Server: s.Server}
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
}
// wrap the HTTP handler with a handler that does MITM detection
tlsh := &tlsHandler{next: s.Server.Handler}
s.Server.Handler = tlsh // this needs to be the "outer" handler when Serve() is called, for type assertion
// when Serve() creates the TLS listener later, that listener should
// be adding a reference the ClientHello info to a map; this callback
// will be sure to clear out that entry when the connection closes.
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
// when a connection closes or is hijacked, delete its entry
// in the map, because we are done with it.
if tlsh.listener != nil {
if cs == http.StateHijacked || cs == http.StateClosed {
tlsh.listener.helloInfosMu.Lock()
delete(tlsh.listener.helloInfos, c.RemoteAddr().String())
tlsh.listener.helloInfosMu.Unlock()
}
}
}
// As of Go 1.7, if the Server's TLSConfig is not nil, HTTP/2 is enabled only
// if TLSConfig.NextProtos includes the string "h2"
if HTTP2 && len(s.Server.TLSConfig.NextProtos) == 0 {
// some experimenting shows that this NextProtos must have at least
// one value that overlaps with the NextProtos of any other tls.Config
// that is returned from GetConfigForClient; if there is no overlap,
// the connection will fail (as of Go 1.8, Feb. 2017).
s.Server.TLSConfig.NextProtos = defaultALPN
}
}
// Compile custom middleware for every site (enables virtual hosting)
for _, site := range group {
stack := Handler(staticfiles.FileServer{Root: http.Dir(site.Root), Hide: site.HiddenFiles, IndexPages: site.IndexPages})
for i := len(site.middleware) - 1; i >= 0; i-- {
stack = site.middleware[i](stack)
}
site.middlewareChain = stack
s.vhosts.Insert(site.Addr.VHost(), site)
}
return s, nil
}
// makeHTTPServerWithHeaderLimit apply minimum header limit within a group to given http.Server
func makeHTTPServerWithHeaderLimit(s *http.Server, group []*SiteConfig) *http.Server {
var min int64
for _, cfg := range group {
limit := cfg.Limits.MaxRequestHeaderSize
if limit == 0 {
continue
}
// not set yet
if min == 0 {
min = limit
}
// find a better one
if limit < min {
min = limit
}
}
if min > 0 {
s.MaxHeaderBytes = int(min)
}
return s
}
// makeHTTPServerWithTimeouts makes an http.Server from the group of
// configs in a way that configures timeouts (or, if not set, it uses
// the default timeouts) by combining the configuration of each
// SiteConfig in the group. (Timeouts are important for mitigating
// slowloris attacks.)
func makeHTTPServerWithTimeouts(addr string, group []*SiteConfig) *http.Server {
// find the minimum duration configured for each timeout
var min Timeouts
for _, cfg := range group {
if cfg.Timeouts.ReadTimeoutSet &&
(!min.ReadTimeoutSet || cfg.Timeouts.ReadTimeout < min.ReadTimeout) {
min.ReadTimeoutSet = true
min.ReadTimeout = cfg.Timeouts.ReadTimeout
}
if cfg.Timeouts.ReadHeaderTimeoutSet &&
(!min.ReadHeaderTimeoutSet || cfg.Timeouts.ReadHeaderTimeout < min.ReadHeaderTimeout) {
min.ReadHeaderTimeoutSet = true
min.ReadHeaderTimeout = cfg.Timeouts.ReadHeaderTimeout
}
if cfg.Timeouts.WriteTimeoutSet &&
(!min.WriteTimeoutSet || cfg.Timeouts.WriteTimeout < min.WriteTimeout) {
min.WriteTimeoutSet = true
min.WriteTimeout = cfg.Timeouts.WriteTimeout
}
if cfg.Timeouts.IdleTimeoutSet &&
(!min.IdleTimeoutSet || cfg.Timeouts.IdleTimeout < min.IdleTimeout) {
min.IdleTimeoutSet = true
min.IdleTimeout = cfg.Timeouts.IdleTimeout
}
}
// for the values that were not set, use defaults
if !min.ReadTimeoutSet {
min.ReadTimeout = defaultTimeouts.ReadTimeout
}
if !min.ReadHeaderTimeoutSet {
min.ReadHeaderTimeout = defaultTimeouts.ReadHeaderTimeout
}
if !min.WriteTimeoutSet {
min.WriteTimeout = defaultTimeouts.WriteTimeout
}
if !min.IdleTimeoutSet {
min.IdleTimeout = defaultTimeouts.IdleTimeout
}
// set the final values on the server and return it
return &http.Server{
Addr: addr,
ReadTimeout: min.ReadTimeout,
ReadHeaderTimeout: min.ReadHeaderTimeout,
WriteTimeout: min.WriteTimeout,
IdleTimeout: min.IdleTimeout,
}
}
func (s *Server) wrapWithSvcHeaders(previousHandler http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.quicServer.SetQuicHeaders(w.Header())
previousHandler.ServeHTTP(w, r)
}
}
// Listen creates an active listener for s that can be
// used to serve requests.
func (s *Server) Listen() (net.Listener, error) {
if s.Server == nil {
return nil, fmt.Errorf("Server field is nil")
}
ln, err := net.Listen("tcp", s.Server.Addr)
if err != nil {
var succeeded bool
if runtime.GOOS == "windows" {
// Windows has been known to keep sockets open even after closing the listeners.
// Tests reveal this error case easily because they call Start() then Stop()
// in succession. TODO: Better way to handle this? And why limit this to Windows?
for i := 0; i < 20; i++ {
time.Sleep(100 * time.Millisecond)
ln, err = net.Listen("tcp", s.Server.Addr)
if err == nil {
succeeded = true
break
}
}
}
if !succeeded {
return nil, err
}
}
if tcpLn, ok := ln.(*net.TCPListener); ok {
ln = tcpKeepAliveListener{TCPListener: tcpLn}
}
cln := ln.(caddy.Listener)
for _, site := range s.sites {
for _, m := range site.listenerMiddleware {
cln = m(cln)
}
}
// Very important to return a concrete caddy.Listener
// implementation for graceful restarts.
return cln.(caddy.Listener), nil
}
// ListenPacket creates udp connection for QUIC if it is enabled,
func (s *Server) ListenPacket() (net.PacketConn, error) {
if QUIC {
udpAddr, err := net.ResolveUDPAddr("udp", s.Server.Addr)
if err != nil {
return nil, err
}
return net.ListenUDP("udp", udpAddr)
}
return nil, nil
}
// Serve serves requests on ln. It blocks until ln is closed.
func (s *Server) Serve(ln net.Listener) error {
s.listenerMu.Lock()
s.listener = ln
s.listenerMu.Unlock()
if s.Server.TLSConfig != nil {
// Create TLS listener - note that we do not replace s.listener
// with this TLS listener; tls.listener is unexported and does
// not implement the File() method we need for graceful restarts
// on POSIX systems.
// TODO: Is this ^ still relevant anymore? Maybe we can now that it's a net.Listener...
ln = newTLSListener(ln, s.Server.TLSConfig)
if handler, ok := s.Server.Handler.(*tlsHandler); ok {
handler.listener = ln.(*tlsHelloListener)
}
// Rotate TLS session ticket keys
s.tlsGovChan = caddytls.RotateSessionTicketKeys(s.Server.TLSConfig)
}
err := s.Server.Serve(ln)
if err == http.ErrServerClosed {
err = nil // not an error worth reporting since closing a server is intentional
}
if s.quicServer != nil {
s.quicServer.Close()
}
return err
}
// ServePacket serves QUIC requests on pc until it is closed.
func (s *Server) ServePacket(pc net.PacketConn) error {
if s.quicServer != nil {
err := s.quicServer.Serve(pc.(*net.UDPConn))
return fmt.Errorf("serving QUIC connections: %v", err)
}
return nil
}
// ServeHTTP is the entry point of all HTTP requests.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
// We absolutely need to be sure we stay alive up here,
// even though, in theory, the errors middleware does this.
if rec := recover(); rec != nil {
log.Printf("[PANIC] %v", rec)
DefaultErrorFunc(w, r, http.StatusInternalServerError)
}
}()
// record the User-Agent string (with a cap on its length to mitigate attacks)
ua := r.Header.Get("User-Agent")
if len(ua) > 512 {
ua = ua[:512]
}
uaHash := telemetry.FastHash([]byte(ua)) // this is a normalized field
go telemetry.SetNested("http_user_agent", uaHash, ua)
go telemetry.AppendUnique("http_user_agent_count", uaHash)
go telemetry.Increment("http_request_count")
// copy the original, unchanged URL into the context
// so it can be referenced by middlewares
urlCopy := *r.URL
if r.URL.User != nil {
userInfo := new(url.Userinfo)
*userInfo = *r.URL.User
urlCopy.User = userInfo
}
c := context.WithValue(r.Context(), OriginalURLCtxKey, urlCopy)
r = r.WithContext(c)
// Setup a replacer for the request that keeps track of placeholder
// values across plugins.
replacer := NewReplacer(r, nil, "")
c = context.WithValue(r.Context(), ReplacerCtxKey, replacer)
r = r.WithContext(c)
w.Header().Set("Server", caddy.AppName)
status, _ := s.serveHTTP(w, r)
// Fallback error response in case error handling wasn't chained in
if status >= 400 {
DefaultErrorFunc(w, r, status)
}
}
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// strip out the port because it's not used in virtual
// hosting; the port is irrelevant because each listener
// is on a different port.
hostname, _, err := net.SplitHostPort(r.Host)
if err != nil {
hostname = r.Host
}
// look up the virtualhost; if no match, serve error
vhost, pathPrefix := s.vhosts.Match(hostname + r.URL.Path)
c := context.WithValue(r.Context(), caddy.CtxKey("path_prefix"), pathPrefix)
r = r.WithContext(c)
if vhost == nil {
// check for ACME challenge even if vhost is nil;
// could be a new host coming online soon
if caddytls.HTTPChallengeHandler(w, r, "localhost") {
return 0, nil
}
// otherwise, log the error and write a message to the client
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
remoteHost = r.RemoteAddr
}
WriteSiteNotFound(w, r) // don't add headers outside of this function
log.Printf("[INFO] %s - No such site at %s (Remote: %s, Referer: %s)",
hostname, s.Server.Addr, remoteHost, r.Header.Get("Referer"))
return 0, nil
}
// we still check for ACME challenge if the vhost exists,
// because we must apply its HTTP challenge config settings
if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) {
return 0, nil
}
// trim the path portion of the site address from the beginning of
// the URL path, so a request to example.com/foo/blog on the site
// defined as example.com/foo appears as /blog instead of /foo/blog.
if pathPrefix != "/" {
r.URL = trimPathPrefix(r.URL, pathPrefix)
}
// enforce strict host matching, which ensures that the SNI
// value (if any), matches the Host header; essential for
// sites that rely on TLS ClientAuth sharing a port with
// sites that do not - if mismatched, close the connection
if vhost.StrictHostMatching && r.TLS != nil &&
strings.ToLower(r.TLS.ServerName) != strings.ToLower(hostname) {
r.Close = true
log.Printf("[ERROR] %s - strict host matching: SNI (%s) and HTTP Host (%s) values differ",
vhost.Addr, r.TLS.ServerName, hostname)
return http.StatusForbidden, nil
}
return vhost.middlewareChain.ServeHTTP(w, r)
}
func trimPathPrefix(u *url.URL, prefix string) *url.URL {
// We need to use URL.EscapedPath() when trimming the pathPrefix as
// URL.Path is ambiguous about / or %2f - see docs. See #1927
trimmedPath := strings.TrimPrefix(u.EscapedPath(), prefix)
if !strings.HasPrefix(trimmedPath, "/") {
trimmedPath = "/" + trimmedPath
}
// After trimming path reconstruct uri string with Query before parsing
trimmedURI := trimmedPath
if u.RawQuery != "" || u.ForceQuery == true {
trimmedURI = trimmedPath + "?" + u.RawQuery
}
if u.Fragment != "" {
trimmedURI = trimmedURI + "#" + u.Fragment
}
trimmedURL, err := url.Parse(trimmedURI)
if err != nil {
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmedURI, err)
return u
}
return trimmedURL
}
// Address returns the address s was assigned to listen on.
func (s *Server) Address() string {
return s.Server.Addr
}
// Stop stops s gracefully (or forcefully after timeout) and
// closes its listener.
func (s *Server) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), s.connTimeout)
defer cancel()
err := s.Server.Shutdown(ctx)
if err != nil {
return err
}
// signal any TLS governor goroutines to exit
if s.tlsGovChan != nil {
close(s.tlsGovChan)
}
return nil
}
// OnStartupComplete lists the sites served by this server
// and any relevant information, assuming caddy.Quiet == false.
func (s *Server) OnStartupComplete() {
if caddy.Quiet {
return
}
for _, site := range s.sites {
output := site.Addr.String()
if caddy.IsLoopback(s.Address()) && !caddy.IsLoopback(site.Addr.Host) {
output += " (only accessible on this machine)"
}
fmt.Println(output)
log.Println(output)
}
}
// defaultTimeouts stores the default timeout values to use
// if left unset by user configuration. NOTE: Most default
// timeouts are disabled (see issues #1464 and #1733).
var defaultTimeouts = Timeouts{IdleTimeout: 5 * time.Minute}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
//
// Borrowed from the Go standard library.
type tcpKeepAliveListener struct {
*net.TCPListener
}
// Accept accepts the connection with a keep-alive enabled.
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
// File implements caddy.Listener; it returns the underlying file of the listener.
func (ln tcpKeepAliveListener) File() (*os.File, error) {
return ln.TCPListener.File()
}
// ErrMaxBytesExceeded is the error returned by MaxBytesReader
// when the request body exceeds the limit imposed
var ErrMaxBytesExceeded = errors.New("http: request body too large")
// DefaultErrorFunc responds to an HTTP request with a simple description
// of the specified HTTP status code.
func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) {
WriteTextResponse(w, status, fmt.Sprintf("%d %s\n", status, http.StatusText(status)))
}
const httpStatusMisdirectedRequest = 421 // RFC 7540, 9.1.2
// WriteSiteNotFound writes appropriate error code to w, signaling that
// requested host is not served by Caddy on a given port.
func WriteSiteNotFound(w http.ResponseWriter, r *http.Request) {
status := http.StatusNotFound
if r.ProtoMajor >= 2 {
// TODO: use http.StatusMisdirectedRequest when it gets defined
status = httpStatusMisdirectedRequest
}
WriteTextResponse(w, status, fmt.Sprintf("%d Site %s is not served on this interface\n", status, r.Host))
}
// WriteTextResponse writes body with code status to w. The body will
// be interpreted as plain text.
func WriteTextResponse(w http.ResponseWriter, status int, body string) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(status)
w.Write([]byte(body))
}
// SafePath joins siteRoot and reqPath and converts it to a path that can
// be used to access a path on the local disk. It ensures the path does
// not traverse outside of the site root.
//
// If opening a file, use http.Dir instead.
func SafePath(siteRoot, reqPath string) string {
reqPath = filepath.ToSlash(reqPath)
reqPath = strings.Replace(reqPath, "\x00", "", -1) // NOTE: Go 1.9 checks for null bytes in the syscall package
if siteRoot == "" {
siteRoot = "."
}
return filepath.Join(siteRoot, filepath.FromSlash(path.Clean("/"+reqPath)))
}
// OriginalURLCtxKey is the key for accessing the original, incoming URL on an HTTP request.
const OriginalURLCtxKey = caddy.CtxKey("original_url")

View File

@@ -0,0 +1,269 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"net/http"
"net/url"
"testing"
"time"
)
func TestAddress(t *testing.T) {
addr := "127.0.0.1:9005"
srv := &Server{Server: &http.Server{Addr: addr}}
if got, want := srv.Address(), addr; got != want {
t.Errorf("Expected '%s' but got '%s'", want, got)
}
}
func TestMakeHTTPServerWithTimeouts(t *testing.T) {
for i, tc := range []struct {
group []*SiteConfig
expected Timeouts
}{
{
group: []*SiteConfig{{Timeouts: Timeouts{}}},
expected: Timeouts{
ReadTimeout: defaultTimeouts.ReadTimeout,
ReadHeaderTimeout: defaultTimeouts.ReadHeaderTimeout,
WriteTimeout: defaultTimeouts.WriteTimeout,
IdleTimeout: defaultTimeouts.IdleTimeout,
},
},
{
group: []*SiteConfig{{Timeouts: Timeouts{
ReadTimeout: 1 * time.Second,
ReadTimeoutSet: true,
ReadHeaderTimeout: 2 * time.Second,
ReadHeaderTimeoutSet: true,
}}},
expected: Timeouts{
ReadTimeout: 1 * time.Second,
ReadHeaderTimeout: 2 * time.Second,
WriteTimeout: defaultTimeouts.WriteTimeout,
IdleTimeout: defaultTimeouts.IdleTimeout,
},
},
{
group: []*SiteConfig{{Timeouts: Timeouts{
ReadTimeoutSet: true,
WriteTimeoutSet: true,
}}},
expected: Timeouts{
ReadTimeout: 0,
ReadHeaderTimeout: defaultTimeouts.ReadHeaderTimeout,
WriteTimeout: 0,
IdleTimeout: defaultTimeouts.IdleTimeout,
},
},
{
group: []*SiteConfig{
{Timeouts: Timeouts{
ReadTimeout: 2 * time.Second,
ReadTimeoutSet: true,
WriteTimeout: 2 * time.Second,
WriteTimeoutSet: true,
}},
{Timeouts: Timeouts{
ReadTimeout: 1 * time.Second,
ReadTimeoutSet: true,
WriteTimeout: 1 * time.Second,
WriteTimeoutSet: true,
}},
},
expected: Timeouts{
ReadTimeout: 1 * time.Second,
ReadHeaderTimeout: defaultTimeouts.ReadHeaderTimeout,
WriteTimeout: 1 * time.Second,
IdleTimeout: defaultTimeouts.IdleTimeout,
},
},
{
group: []*SiteConfig{{Timeouts: Timeouts{
ReadHeaderTimeout: 5 * time.Second,
ReadHeaderTimeoutSet: true,
IdleTimeout: 10 * time.Second,
IdleTimeoutSet: true,
}}},
expected: Timeouts{
ReadTimeout: defaultTimeouts.ReadTimeout,
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: defaultTimeouts.WriteTimeout,
IdleTimeout: 10 * time.Second,
},
},
} {
actual := makeHTTPServerWithTimeouts("127.0.0.1:9005", tc.group)
if got, want := actual.Addr, "127.0.0.1:9005"; got != want {
t.Errorf("Test %d: Expected Addr=%s, but was %s", i, want, got)
}
if got, want := actual.ReadTimeout, tc.expected.ReadTimeout; got != want {
t.Errorf("Test %d: Expected ReadTimeout=%v, but was %v", i, want, got)
}
if got, want := actual.ReadHeaderTimeout, tc.expected.ReadHeaderTimeout; got != want {
t.Errorf("Test %d: Expected ReadHeaderTimeout=%v, but was %v", i, want, got)
}
if got, want := actual.WriteTimeout, tc.expected.WriteTimeout; got != want {
t.Errorf("Test %d: Expected WriteTimeout=%v, but was %v", i, want, got)
}
if got, want := actual.IdleTimeout, tc.expected.IdleTimeout; got != want {
t.Errorf("Test %d: Expected IdleTimeout=%v, but was %v", i, want, got)
}
}
}
func TestTrimPathPrefix(t *testing.T) {
for i, pt := range []struct {
url string
prefix string
expected string
shouldFail bool
}{
{
url: "/my/path",
prefix: "/my",
expected: "/path",
shouldFail: false,
},
{
url: "/my/%2f/path",
prefix: "/my",
expected: "/%2f/path",
shouldFail: false,
},
{
url: "/my/path",
prefix: "/my/",
expected: "/path",
shouldFail: false,
},
{
url: "/my///path",
prefix: "/my",
expected: "/path",
shouldFail: true,
},
{
url: "/my///path",
prefix: "/my",
expected: "///path",
shouldFail: false,
},
{
url: "/my/path///slash",
prefix: "/my",
expected: "/path///slash",
shouldFail: false,
},
{
url: "/my/%2f/path/%2f",
prefix: "/my",
expected: "/%2f/path/%2f",
shouldFail: false,
}, {
url: "/my/%20/path",
prefix: "/my",
expected: "/%20/path",
shouldFail: false,
}, {
url: "/path",
prefix: "",
expected: "/path",
shouldFail: false,
}, {
url: "/path/my/",
prefix: "/my",
expected: "/path/my/",
shouldFail: false,
}, {
url: "",
prefix: "/my",
expected: "/",
shouldFail: false,
}, {
url: "/apath",
prefix: "",
expected: "/apath",
shouldFail: false,
}, {
url: "/my/path/page.php?akey=value",
prefix: "/my",
expected: "/path/page.php?akey=value",
shouldFail: false,
}, {
url: "/my/path/page?key=value#fragment",
prefix: "/my",
expected: "/path/page?key=value#fragment",
shouldFail: false,
}, {
url: "/my/path/page#fragment",
prefix: "/my",
expected: "/path/page#fragment",
shouldFail: false,
}, {
url: "/my/apath?",
prefix: "/my",
expected: "/apath?",
shouldFail: false,
},
} {
u, _ := url.Parse(pt.url)
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.String() != want {
if !pt.shouldFail {
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.String())
}
} else if pt.shouldFail {
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.String())
}
}
}
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
for name, c := range map[string]struct {
group []*SiteConfig
expect int
}{
"disable": {
group: []*SiteConfig{{}},
expect: 0,
},
"oneSite": {
group: []*SiteConfig{{Limits: Limits{
MaxRequestHeaderSize: 100,
}}},
expect: 100,
},
"multiSites": {
group: []*SiteConfig{
{Limits: Limits{MaxRequestHeaderSize: 100}},
{Limits: Limits{MaxRequestHeaderSize: 50}},
},
expect: 50,
},
} {
c := c
t.Run(name, func(t *testing.T) {
actual := makeHTTPServerWithHeaderLimit(&http.Server{}, c.group)
if got := actual.MaxHeaderBytes; got != c.expect {
t.Errorf("Expect %d, but got %d", c.expect, got)
}
})
}
}

View File

@@ -0,0 +1,151 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"time"
"github.com/mholt/caddy/caddytls"
)
// SiteConfig contains information about a site
// (also known as a virtual host).
type SiteConfig struct {
// The address of the site
Addr Address
// The list of viable index page names of the site
IndexPages []string
// The hostname to bind listener to;
// defaults to Addr.Host
ListenHost string
// TLS configuration
TLS *caddytls.Config
// If true, the Host header in the HTTP request must
// match the SNI value in the TLS handshake (if any).
// This should be enabled whenever a site relies on
// TLS client authentication, for example; or any time
// you want to enforce that THIS site's TLS config
// is used and not the TLS config of any other site
// on the same listener. TODO: Check how relevant this
// is with TLS 1.3.
StrictHostMatching bool
// Uncompiled middleware stack
middleware []Middleware
// Compiled middleware stack
middlewareChain Handler
// listener middleware stack
listenerMiddleware []ListenerMiddleware
// Directory from which to serve files
Root string
// A list of files to hide (for example, the
// source Caddyfile). TODO: Enforcing this
// should be centralized, for example, a
// standardized way of loading files from disk
// for a request.
HiddenFiles []string
// Max request's header/body size
Limits Limits
// The path to the Caddyfile used to generate this site config
originCaddyfile string
// These timeout values are used, in conjunction with other
// site configs on the same server instance, to set the
// respective timeout values on the http.Server that
// is created. Sensible values will mitigate slowloris
// attacks and overcome faulty networks, while still
// preserving functionality needed for proxying,
// websockets, etc.
Timeouts Timeouts
// If true, any requests not matching other site definitions
// may be served by this site.
FallbackSite bool
}
// Timeouts specify various timeouts for a server to use.
// If the assocated bool field is true, then the duration
// value should be treated literally (i.e. a zero-value
// duration would mean "no timeout"). If false, the duration
// was left unset, so a zero-value duration would mean to
// use a default value (even if default is non-zero).
type Timeouts struct {
ReadTimeout time.Duration
ReadTimeoutSet bool
ReadHeaderTimeout time.Duration
ReadHeaderTimeoutSet bool
WriteTimeout time.Duration
WriteTimeoutSet bool
IdleTimeout time.Duration
IdleTimeoutSet bool
}
// Limits specify size limit of request's header and body.
type Limits struct {
MaxRequestHeaderSize int64
MaxRequestBodySizes []PathLimit
}
// PathLimit is a mapping from a site's path to its corresponding
// maximum request body size (in bytes)
type PathLimit struct {
Path string
Limit int64
}
// AddMiddleware adds a middleware to a site's middleware stack.
func (s *SiteConfig) AddMiddleware(m Middleware) {
s.middleware = append(s.middleware, m)
}
// AddListenerMiddleware adds a listener middleware to a site's listenerMiddleware stack.
func (s *SiteConfig) AddListenerMiddleware(l ListenerMiddleware) {
s.listenerMiddleware = append(s.listenerMiddleware, l)
}
// TLSConfig returns s.TLS.
func (s SiteConfig) TLSConfig() *caddytls.Config {
return s.TLS
}
// Host returns s.Addr.Host.
func (s SiteConfig) Host() string {
return s.Addr.Host
}
// Port returns s.Addr.Port.
func (s SiteConfig) Port() string {
return s.Addr.Port
}
// Middleware returns s.middleware (useful for tests).
func (s SiteConfig) Middleware() []Middleware {
return s.middleware
}
// ListenerMiddleware returns s.listenerMiddleware
func (s SiteConfig) ListenerMiddleware() []ListenerMiddleware {
return s.listenerMiddleware
}

View File

@@ -0,0 +1,460 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"bytes"
"crypto/rand"
"fmt"
"io/ioutil"
mathrand "math/rand"
"net"
"net/http"
"net/url"
"path"
"strings"
"sync"
"text/template"
"time"
"os"
"github.com/russross/blackfriday"
)
// This file contains the context and functions available for
// use in the templates.
// Context is the context with which Caddy templates are executed.
type Context struct {
Root http.FileSystem
Req *http.Request
URL *url.URL
Args []interface{} // defined by arguments to .Include
// just used for adding preload links for server push
responseHeader http.Header
}
// NewContextWithHeader creates a context with given response header.
//
// To plugin developer:
// The returned context's exported fileds remain empty,
// you should then initialize them if you want.
func NewContextWithHeader(rh http.Header) Context {
return Context{
responseHeader: rh,
}
}
// Include returns the contents of filename relative to the site root.
func (c Context) Include(filename string, args ...interface{}) (string, error) {
c.Args = args
return ContextInclude(filename, c, c.Root)
}
// Now returns the current timestamp in the specified format.
func (c Context) Now(format string) string {
return time.Now().Format(format)
}
// NowDate returns the current date/time that can be used
// in other time functions.
func (c Context) NowDate() time.Time {
return time.Now()
}
// Cookie gets the value of a cookie with name name.
func (c Context) Cookie(name string) string {
cookies := c.Req.Cookies()
for _, cookie := range cookies {
if cookie.Name == name {
return cookie.Value
}
}
return ""
}
// Header gets the value of a request header with field name.
func (c Context) Header(name string) string {
return c.Req.Header.Get(name)
}
// Hostname gets the (remote) hostname of the client making the request.
func (c Context) Hostname() string {
ip := c.IP()
hostnameList, err := net.LookupAddr(ip)
if err != nil || len(hostnameList) == 0 {
return c.Req.RemoteAddr
}
return hostnameList[0]
}
// Env gets a map of the environment variables.
func (c Context) Env() map[string]string {
osEnv := os.Environ()
envVars := make(map[string]string, len(osEnv))
for _, env := range osEnv {
data := strings.SplitN(env, "=", 2)
if len(data) == 2 && len(data[0]) > 0 {
envVars[data[0]] = data[1]
}
}
return envVars
}
// IP gets the (remote) IP address of the client making the request.
func (c Context) IP() string {
ip, _, err := net.SplitHostPort(c.Req.RemoteAddr)
if err != nil {
return c.Req.RemoteAddr
}
return ip
}
// To mock the net.InterfaceAddrs from the test.
var networkInterfacesFn = net.InterfaceAddrs
// ServerIP gets the (local) IP address of the server.
// TODO: The bind directive should be honored in this method (see PR #1474).
func (c Context) ServerIP() string {
addrs, err := networkInterfacesFn()
if err != nil {
return ""
}
for _, address := range addrs {
// Validate the address and check if it's not a loopback
if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ipnet.IP.To4() != nil || ipnet.IP.To16() != nil {
return ipnet.IP.String()
}
}
}
return ""
}
// URI returns the raw, unprocessed request URI (including query
// string and hash) obtained directly from the Request-Line of
// the HTTP request.
func (c Context) URI() string {
return c.Req.RequestURI
}
// Host returns the hostname portion of the Host header
// from the HTTP request.
func (c Context) Host() (string, error) {
host, _, err := net.SplitHostPort(c.Req.Host)
if err != nil {
if !strings.Contains(c.Req.Host, ":") {
// common with sites served on the default port 80
return c.Req.Host, nil
}
return "", err
}
return host, nil
}
// Port returns the port portion of the Host header if specified.
func (c Context) Port() (string, error) {
_, port, err := net.SplitHostPort(c.Req.Host)
if err != nil {
if !strings.Contains(c.Req.Host, ":") {
// common with sites served on the default port 80
return HTTPPort, nil
}
return "", err
}
return port, nil
}
// Method returns the method (GET, POST, etc.) of the request.
func (c Context) Method() string {
return c.Req.Method
}
// PathMatches returns true if the path portion of the request
// URL matches pattern.
func (c Context) PathMatches(pattern string) bool {
return Path(c.Req.URL.Path).Matches(pattern)
}
// Truncate truncates the input string to the given length.
// If length is negative, it returns that many characters
// starting from the end of the string. If the absolute value
// of length is greater than len(input), the whole input is
// returned.
func (c Context) Truncate(input string, length int) string {
if length < 0 && len(input)+length > 0 {
return input[len(input)+length:]
}
if length >= 0 && len(input) > length {
return input[:length]
}
return input
}
// StripHTML returns s without HTML tags. It is fairly naive
// but works with most valid HTML inputs.
func (c Context) StripHTML(s string) string {
var buf bytes.Buffer
var inTag, inQuotes bool
var tagStart int
for i, ch := range s {
if inTag {
if ch == '>' && !inQuotes {
inTag = false
} else if ch == '<' && !inQuotes {
// false start
buf.WriteString(s[tagStart:i])
tagStart = i
} else if ch == '"' {
inQuotes = !inQuotes
}
continue
}
if ch == '<' {
inTag = true
tagStart = i
continue
}
buf.WriteRune(ch)
}
if inTag {
// false start
buf.WriteString(s[tagStart:])
}
return buf.String()
}
// Ext returns the suffix beginning at the final dot in the final
// slash-separated element of the pathStr (or in other words, the
// file extension).
func (c Context) Ext(pathStr string) string {
return path.Ext(pathStr)
}
// StripExt returns the input string without the extension,
// which is the suffix starting with the final '.' character
// but not before the final path separator ('/') character.
// If there is no extension, the whole input is returned.
func (c Context) StripExt(path string) string {
for i := len(path) - 1; i >= 0 && path[i] != '/'; i-- {
if path[i] == '.' {
return path[:i]
}
}
return path
}
// Replace replaces instances of find in input with replacement.
func (c Context) Replace(input, find, replacement string) string {
return strings.Replace(input, find, replacement, -1)
}
// Markdown returns the HTML contents of the markdown contained in filename
// (relative to the site root).
func (c Context) Markdown(filename string) (string, error) {
body, err := c.Include(filename)
if err != nil {
return "", err
}
renderer := blackfriday.HtmlRenderer(0, "", "")
extns := 0
extns |= blackfriday.EXTENSION_TABLES
extns |= blackfriday.EXTENSION_FENCED_CODE
extns |= blackfriday.EXTENSION_STRIKETHROUGH
extns |= blackfriday.EXTENSION_DEFINITION_LISTS
markdown := blackfriday.Markdown([]byte(body), renderer, extns)
return string(markdown), nil
}
// ContextInclude opens filename using fs and executes a template with the context ctx.
// This does the same thing that Context.Include() does, but with the ability to provide
// your own context so that the included files can have access to additional fields your
// type may provide. You can embed Context in your type, then override its Include method
// to call this function with ctx being the instance of your type, and fs being Context.Root.
func ContextInclude(filename string, ctx interface{}, fs http.FileSystem) (string, error) {
file, err := fs.Open(filename)
if err != nil {
return "", err
}
defer file.Close()
body, err := ioutil.ReadAll(file)
if err != nil {
return "", err
}
tpl, err := template.New(filename).Funcs(TemplateFuncs).Parse(string(body))
if err != nil {
return "", err
}
buf := includeBufs.Get().(*bytes.Buffer)
buf.Reset()
defer includeBufs.Put(buf)
err = tpl.Execute(buf, ctx)
if err != nil {
return "", err
}
return buf.String(), nil
}
// ToLower will convert the given string to lower case.
func (c Context) ToLower(s string) string {
return strings.ToLower(s)
}
// ToUpper will convert the given string to upper case.
func (c Context) ToUpper(s string) string {
return strings.ToUpper(s)
}
// Split is a pass-through to strings.Split. It will split the first argument at each instance of the separator and return a slice of strings.
func (c Context) Split(s string, sep string) []string {
return strings.Split(s, sep)
}
// Join is a pass-through to strings.Join. It will join the first argument slice with the separator in the second argument and return the result.
func (c Context) Join(a []string, sep string) string {
return strings.Join(a, sep)
}
// Slice will convert the given arguments into a slice.
func (c Context) Slice(elems ...interface{}) []interface{} {
return elems
}
// Map will convert the arguments into a map. It expects alternating string keys and values. This is useful for building more complicated data structures
// if you are using subtemplates or things like that.
func (c Context) Map(values ...interface{}) (map[string]interface{}, error) {
if len(values)%2 != 0 {
return nil, fmt.Errorf("Map expects an even number of arguments")
}
dict := make(map[string]interface{}, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].(string)
if !ok {
return nil, fmt.Errorf("Map keys must be strings")
}
dict[key] = values[i+1]
}
return dict, nil
}
// Files reads and returns a slice of names from the given directory
// relative to the root of Context c.
func (c Context) Files(name string) ([]string, error) {
dir, err := c.Root.Open(path.Clean(name))
if err != nil {
return nil, err
}
defer dir.Close()
stat, err := dir.Stat()
if err != nil {
return nil, err
}
if !stat.IsDir() {
return nil, fmt.Errorf("%v is not a directory", name)
}
dirInfo, err := dir.Readdir(0)
if err != nil {
return nil, err
}
names := make([]string, len(dirInfo))
for i, fileInfo := range dirInfo {
names[i] = fileInfo.Name()
}
return names, nil
}
// IsMITM returns true if it seems likely that the TLS connection
// is being intercepted.
func (c Context) IsMITM() bool {
if val, ok := c.Req.Context().Value(MitmCtxKey).(bool); ok {
return val
}
return false
}
// RandomString generates a random string of random length given
// length bounds. Thanks to http://stackoverflow.com/a/35615565/1048862
// for the clever technique that is fairly fast, secure, and maintains
// proper distributions over the dictionary.
func (c Context) RandomString(minLen, maxLen int) string {
const (
letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
letterIdxBits = 6 // 6 bits to represent 64 possibilities (indexes)
letterIdxMask = 1<<letterIdxBits - 1 // all 1-bits, as many as letterIdxBits
)
if minLen < 0 || maxLen < 0 || maxLen < minLen {
return ""
}
n := mathrand.Intn(maxLen-minLen+1) + minLen // choose actual length
// secureRandomBytes returns a number of bytes using crypto/rand.
secureRandomBytes := func(numBytes int) []byte {
randomBytes := make([]byte, numBytes)
rand.Read(randomBytes)
return randomBytes
}
result := make([]byte, n)
bufferSize := int(float64(n) * 1.3)
for i, j, randomBytes := 0, 0, []byte{}; i < n; j++ {
if j%bufferSize == 0 {
randomBytes = secureRandomBytes(bufferSize)
}
if idx := int(randomBytes[j%n] & letterIdxMask); idx < len(letterBytes) {
result[i] = letterBytes[idx]
i++
}
}
return string(result)
}
// AddLink adds a link header in response
// see https://www.w3.org/wiki/LinkHeader
func (c Context) AddLink(link string) string {
if c.responseHeader == nil {
return ""
}
c.responseHeader.Add("Link", link)
return ""
}
// buffer pool for .Include context actions
var includeBufs = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
// TemplateFuncs contains user-defined functions
// for execution in templates.
var TemplateFuncs = template.FuncMap{}

View File

@@ -0,0 +1,924 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"sort"
"strings"
"testing"
"time"
"text/template"
)
func TestInclude(t *testing.T) {
context := getContextOrFail(t)
inputFilename := "test_file"
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
defer func() {
err := os.Remove(absInFilePath)
if err != nil && !os.IsNotExist(err) {
t.Fatalf("Failed to clean test file!")
}
}()
tests := []struct {
args []interface{}
fileContent string
expectedContent string
shouldErr bool
expectedErrorContent string
}{
// Test 0 - all good
{
fileContent: `str1 {{ .Root }} str2`,
expectedContent: fmt.Sprintf("str1 %s str2", context.Root),
shouldErr: false,
expectedErrorContent: "",
},
// Test 1 - all good, with args
{
args: []interface{}{"hello", 5},
fileContent: `str1 {{ .Root }} str2 {{index .Args 0}} {{index .Args 1}}`,
expectedContent: fmt.Sprintf("str1 %s str2 %s %d", context.Root, "hello", 5),
shouldErr: false,
expectedErrorContent: "",
},
// Test 2 - failure on template.Parse
{
fileContent: `str1 {{ .Root } str2`,
expectedContent: "",
shouldErr: true,
expectedErrorContent: `unexpected "}" in operand`,
},
// Test 3 - failure on template.Execute
{
fileContent: `str1 {{ .InvalidField }} str2`,
expectedContent: "",
shouldErr: true,
expectedErrorContent: `InvalidField`,
},
{
fileContent: `str1 {{ .InvalidField }} str2`,
expectedContent: "",
shouldErr: true,
expectedErrorContent: `type httpserver.Context`,
},
// Test 4 - all good, with custom function
{
fileContent: `hello {{ caddy }}`,
expectedContent: "hello caddy",
shouldErr: false,
expectedErrorContent: "",
},
}
TemplateFuncs["caddy"] = func() string { return "caddy" }
for i, test := range tests {
testPrefix := getTestPrefix(i)
// WriteFile truncates the contentt
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
if err != nil {
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
}
content, err := context.Include(inputFilename, test.args...)
if err != nil {
if !test.shouldErr {
t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error())
}
if !strings.Contains(err.Error(), test.expectedErrorContent) {
t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error())
}
}
if err == nil && test.shouldErr {
t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename)
}
if content != test.expectedContent {
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
}
}
}
func TestIncludeNotExisting(t *testing.T) {
context := getContextOrFail(t)
_, err := context.Include("not_existing")
if err == nil {
t.Errorf("Expected error but found nil!")
}
}
func TestMarkdown(t *testing.T) {
context := getContextOrFail(t)
inputFilename := "test_file"
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
defer func() {
err := os.Remove(absInFilePath)
if err != nil && !os.IsNotExist(err) {
t.Fatalf("Failed to clean test file!")
}
}()
tests := []struct {
fileContent string
expectedContent string
}{
// Test 0 - test parsing of markdown
{
fileContent: "* str1\n* str2\n",
expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
// WriteFile truncates the contentt
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
if err != nil {
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
}
content, _ := context.Markdown(inputFilename)
if content != test.expectedContent {
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
}
}
}
func TestCookie(t *testing.T) {
tests := []struct {
cookie *http.Cookie
cookieName string
expectedValue string
}{
// Test 0 - happy path
{
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
cookieName: "cookieName",
expectedValue: "cookieValue",
},
// Test 1 - try to get a non-existing cookie
{
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
cookieName: "notExisting",
expectedValue: "",
},
// Test 2 - partial name match
{
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
cookieName: "cook",
expectedValue: "",
},
// Test 3 - cookie with optional fields
{
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
cookieName: "cookie",
expectedValue: "cookieValue",
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
// reinitialize the context for each test
context := getContextOrFail(t)
context.Req.AddCookie(test.cookie)
actualCookieVal := context.Cookie(test.cookieName)
if actualCookieVal != test.expectedValue {
t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName)
}
}
}
func TestCookieMultipleCookies(t *testing.T) {
context := getContextOrFail(t)
cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
// make sure that there's no state and multiple requests for different cookies return the correct result
for i := 0; i < 10; i++ {
context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)})
}
for i := 0; i < 10; i++ {
expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
if actualCookieVal != expectedCookieVal {
t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
}
}
}
func TestHeader(t *testing.T) {
context := getContextOrFail(t)
headerKey, headerVal := "Header1", "HeaderVal1"
context.Req.Header.Add(headerKey, headerVal)
actualHeaderVal := context.Header(headerKey)
if actualHeaderVal != headerVal {
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
}
missingHeaderVal := context.Header("not-existing")
if missingHeaderVal != "" {
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
}
}
func TestHostname(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
inputRemoteAddr string
expectedHostname string
}{
// TODO(mholt): Fix these tests, they're not portable. i.e. my resolver
// returns "fwdr-8.fwdr-8.fwdr-8.fwdr-8." instead of these google ones.
// Test 0 - ipv4 with port
// {"8.8.8.8:1111", "google-public-dns-a.google.com."},
// // Test 1 - ipv4 without port
// {"8.8.8.8", "google-public-dns-a.google.com."},
// // Test 2 - ipv6 with port
// {"[2001:4860:4860::8888]:11", "google-public-dns-a.google.com."},
// // Test 3 - ipv6 without port and brackets
// {"2001:4860:4860::8888", "google-public-dns-a.google.com."},
// Test 4 - no hostname available
{"0.0.0.0", "0.0.0.0"},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
context.Req.RemoteAddr = test.inputRemoteAddr
actualHostname := context.Hostname()
if actualHostname != test.expectedHostname {
t.Errorf(testPrefix+"Expected hostname %s, found %s", test.expectedHostname, actualHostname)
}
}
}
func TestEnv(t *testing.T) {
context := getContextOrFail(t)
name := "ENV_TEST_NAME"
testValue := "TEST_VALUE"
os.Setenv(name, testValue)
notExisting := "ENV_TEST_NOT_EXISTING"
os.Unsetenv(notExisting)
invalidName := "ENV_TEST_INVALID_NAME"
os.Setenv("="+invalidName, testValue)
env := context.Env()
if value := env[name]; value != testValue {
t.Errorf("Expected env-variable %s value '%s', found '%s'",
name, testValue, value)
}
if value, ok := env[notExisting]; ok {
t.Errorf("Expected empty env-variable %s, found '%s'",
notExisting, value)
}
for k, v := range env {
if strings.Contains(k, invalidName) {
t.Errorf("Expected invalid name not to be included in Env %s, found in key '%s'", invalidName, k)
}
if strings.Contains(v, invalidName) {
t.Errorf("Expected invalid name not be be included in Env %s, found in value '%s'", invalidName, v)
}
}
os.Unsetenv("=" + invalidName)
}
func TestIP(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
inputRemoteAddr string
expectedIP string
}{
// Test 0 - ipv4 with port
{"1.1.1.1:1111", "1.1.1.1"},
// Test 1 - ipv4 without port
{"1.1.1.1", "1.1.1.1"},
// Test 2 - ipv6 with port
{"[::1]:11", "::1"},
// Test 3 - ipv6 without port and brackets
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
// Test 4 - ipv6 with zone and port
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
context.Req.RemoteAddr = test.inputRemoteAddr
actualIP := context.IP()
if actualIP != test.expectedIP {
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
}
}
}
type myIP string
func (ip myIP) mockInterfaces() ([]net.Addr, error) {
a := net.ParseIP(string(ip))
return []net.Addr{
&net.IPNet{IP: a, Mask: nil},
}, nil
}
func TestServerIP(t *testing.T) {
context := getContextOrFail(t)
tests := []string{
// Test 0 - ipv4
"1.1.1.1",
// Test 1 - ipv6
"2001:db8:a0b:12f0::1",
}
for i, expectedIP := range tests {
testPrefix := getTestPrefix(i)
// Mock the network interface
ip := myIP(expectedIP)
networkInterfacesFn = ip.mockInterfaces
defer func() {
networkInterfacesFn = net.InterfaceAddrs
}()
actualIP := context.ServerIP()
if actualIP != expectedIP {
t.Errorf("%sExpected IP \"%s\", found \"%s\".", testPrefix, expectedIP, actualIP)
}
}
}
func TestURL(t *testing.T) {
context := getContextOrFail(t)
inputURL := "http://localhost"
context.Req.RequestURI = inputURL
if inputURL != context.URI() {
t.Errorf("Expected url %s, found %s", inputURL, context.URI())
}
}
func TestHost(t *testing.T) {
tests := []struct {
input string
expectedHost string
shouldErr bool
}{
{
input: "localhost:123",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "localhost",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "[::]",
expectedHost: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
}
}
func TestPort(t *testing.T) {
tests := []struct {
input string
expectedPort string
shouldErr bool
}{
{
input: "localhost:123",
expectedPort: "123",
shouldErr: false,
},
{
input: "localhost",
expectedPort: "80", // assuming 80 is the default port
shouldErr: false,
},
{
input: ":8080",
expectedPort: "8080",
shouldErr: false,
},
{
input: "[::]",
expectedPort: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
}
}
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
context := getContextOrFail(t)
context.Req.Host = input
var actualResult, testedObject string
var err error
if isTestingHost {
actualResult, err = context.Host()
testedObject = "host"
} else {
actualResult, err = context.Port()
testedObject = "port"
}
if shouldErr && err == nil {
t.Errorf("Expected error, found nil!")
return
}
if !shouldErr && err != nil {
t.Errorf("Expected no error, found %s", err)
return
}
if actualResult != expectedResult {
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
}
}
func TestMethod(t *testing.T) {
context := getContextOrFail(t)
method := "POST"
context.Req.Method = method
if method != context.Method() {
t.Errorf("Expected method %s, found %s", method, context.Method())
}
}
func TestContextPathMatches(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
urlStr string
pattern string
shouldMatch bool
}{
// Test 0
{
urlStr: "http://localhost/",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost/",
pattern: "/",
shouldMatch: true,
},
// Test 3
{
urlStr: "http://localhost/?param=val",
pattern: "/",
shouldMatch: true,
},
// Test 4
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir2",
shouldMatch: false,
},
// Test 5
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
// Test 6
{
urlStr: "http://localhost:444/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
// Test 7
{
urlStr: "http://localhost/dir1/dir2",
pattern: "*/dir2",
shouldMatch: false,
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
var err error
context.Req.URL, err = url.Parse(test.urlStr)
if err != nil {
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
}
matches := context.PathMatches(test.pattern)
if matches != test.shouldMatch {
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
}
}
}
func TestTruncate(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
inputString string
inputLength int
expected string
}{
// Test 0 - small length
{
inputString: "string",
inputLength: 1,
expected: "s",
},
// Test 1 - exact length
{
inputString: "string",
inputLength: 6,
expected: "string",
},
// Test 2 - bigger length
{
inputString: "string",
inputLength: 10,
expected: "string",
},
// Test 3 - zero length
{
inputString: "string",
inputLength: 0,
expected: "",
},
// Test 4 - negative, smaller length
{
inputString: "string",
inputLength: -5,
expected: "tring",
},
// Test 5 - negative, exact length
{
inputString: "string",
inputLength: -6,
expected: "string",
},
// Test 6 - negative, bigger length
{
inputString: "string",
inputLength: -7,
expected: "string",
},
}
for i, test := range tests {
actual := context.Truncate(test.inputString, test.inputLength)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
}
}
}
func TestStripHTML(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
input string
expected string
}{
// Test 0 - no tags
{
input: `h1`,
expected: `h1`,
},
// Test 1 - happy path
{
input: `<h1>h1</h1>`,
expected: `h1`,
},
// Test 2 - tag in quotes
{
input: `<h1">">h1</h1>`,
expected: `h1`,
},
// Test 3 - multiple tags
{
input: `<h1><b>h1</b></h1>`,
expected: `h1`,
},
// Test 4 - tags not closed
{
input: `<h1`,
expected: `<h1`,
},
// Test 5 - false start
{
input: `<h1<b>hi`,
expected: `<h1hi`,
},
}
for i, test := range tests {
actual := context.StripHTML(test.input)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input)
}
}
}
func TestStripExt(t *testing.T) {
context := getContextOrFail(t)
tests := []struct {
input string
expected string
}{
// Test 0 - empty input
{
input: "",
expected: "",
},
// Test 1 - relative file with ext
{
input: "file.ext",
expected: "file",
},
// Test 2 - relative file without ext
{
input: "file",
expected: "file",
},
// Test 3 - absolute file without ext
{
input: "/file",
expected: "/file",
},
// Test 4 - absolute file with ext
{
input: "/file.ext",
expected: "/file",
},
// Test 5 - with ext but ends with /
{
input: "/dir.ext/",
expected: "/dir.ext/",
},
// Test 6 - file with ext under dir with ext
{
input: "/dir.ext/file.ext",
expected: "/dir.ext/file",
},
}
for i, test := range tests {
actual := context.StripExt(test.input)
if actual != test.expected {
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input)
}
}
}
func initTestContext() (Context, error) {
body := bytes.NewBufferString("request body")
request, err := http.NewRequest("GET", "https://localhost", body)
if err != nil {
return Context{}, err
}
res := httptest.NewRecorder()
return Context{Root: http.Dir(os.TempDir()), responseHeader: res.Header(), Req: request}, nil
}
func getContextOrFail(t *testing.T) Context {
context, err := initTestContext()
if err != nil {
t.Fatalf("Failed to prepare test context")
}
return context
}
func getTestPrefix(testN int) string {
return fmt.Sprintf("Test [%d]: ", testN)
}
func TestTemplates(t *testing.T) {
tests := []struct{ tmpl, expected string }{
{`{{.ToUpper "aAA"}}`, "AAA"},
{`{{"bbb" | .ToUpper}}`, "BBB"},
{`{{.ToLower "CCc"}}`, "ccc"},
{`{{range (.Split "a,b,c" ",")}}{{.}}{{end}}`, "abc"},
{`{{range .Split "a,b,c" ","}}{{.}}{{end}}`, "abc"},
{`{{range .Slice "a" "b" "c"}}{{.}}{{end}}`, "abc"},
{`{{with .Map "A" "a" "B" "b" "c" "d"}}{{.A}}{{.B}}{{.c}}{{end}}`, "abd"},
}
for i, test := range tests {
ctx := getContextOrFail(t)
tmpl, err := template.New("").Parse(test.tmpl)
if err != nil {
t.Errorf("Test %d: %s", i, err)
continue
}
buf := &bytes.Buffer{}
err = tmpl.Execute(buf, ctx)
if err != nil {
t.Errorf("Test %d: %s", i, err)
continue
}
if buf.String() != test.expected {
t.Errorf("Test %d: Results do not match. '%s' != '%s'", i, buf.String(), test.expected)
}
}
}
func TestFiles(t *testing.T) {
tests := []struct {
fileNames []string
inputBase string
shouldErr bool
verifyErr func(error) bool
}{
// Test 1 - directory and files exist
{
fileNames: []string{"file1", "file2"},
shouldErr: false,
},
// Test 2 - directory exists, no files
{
fileNames: []string{},
shouldErr: false,
},
// Test 3 - file or directory does not exist
{
fileNames: nil,
inputBase: "doesNotExist",
shouldErr: true,
verifyErr: os.IsNotExist,
},
// Test 4 - directory and files exist, but path to a file
{
fileNames: []string{"file1", "file2"},
inputBase: "file1",
shouldErr: true,
verifyErr: func(err error) bool {
return strings.HasSuffix(err.Error(), "is not a directory")
},
},
// Test 5 - try to leave Context Root
{
fileNames: nil,
inputBase: filepath.Join("..", "..", "..", "..", "..", "etc"),
shouldErr: true,
verifyErr: os.IsNotExist,
},
}
for i, test := range tests {
context := getContextOrFail(t)
testPrefix := getTestPrefix(i + 1)
var dirPath string
var err error
// Create directory / files from test case.
if test.fileNames != nil {
dirPath, err = ioutil.TempDir(fmt.Sprintf("%s", context.Root), "caddy_ctxtest")
if err != nil {
os.RemoveAll(dirPath)
t.Fatalf(testPrefix+"Expected no error creating directory, got: '%s'", err.Error())
}
for _, name := range test.fileNames {
absFilePath := filepath.Join(dirPath, name)
if err = ioutil.WriteFile(absFilePath, []byte(""), os.ModePerm); err != nil {
os.RemoveAll(dirPath)
t.Fatalf(testPrefix+"Expected no error creating file, got: '%s'", err.Error())
}
}
}
// Perform test case.
input := filepath.ToSlash(filepath.Join(filepath.Base(dirPath), test.inputBase))
actual, err := context.Files(input)
if err != nil {
if !test.shouldErr {
t.Errorf(testPrefix+"Expected no error, got: '%s'", err.Error())
} else if !test.verifyErr(err) {
t.Errorf(testPrefix+"Could not verify error content, got: '%s'", err.Error())
}
} else if test.shouldErr {
t.Errorf(testPrefix + "Expected error but had none")
} else {
numFiles := len(test.fileNames)
// reflect.DeepEqual does not consider two empty slices to be equal
if numFiles == 0 && len(actual) != 0 {
t.Errorf(testPrefix+"Expected files %v, got: %v",
test.fileNames, actual)
} else {
sort.Strings(actual)
if numFiles > 0 && !reflect.DeepEqual(test.fileNames, actual) {
t.Errorf(testPrefix+"Expected files %v, got: %v",
test.fileNames, actual)
}
}
}
if dirPath != "" {
if err := os.RemoveAll(dirPath); err != nil && !os.IsNotExist(err) {
t.Fatalf(testPrefix+"Expected no error removing directory, got: '%s'", err.Error())
}
}
}
}
func TestAddLink(t *testing.T) {
for name, c := range map[string]struct {
input string
expectLinks []string
}{
"oneLink": {
input: `{{.AddLink "</test.css>; rel=preload"}}`,
expectLinks: []string{"</test.css>; rel=preload"},
},
"multipleLinks": {
input: `{{.AddLink "</test1.css>; rel=preload"}} {{.AddLink "</test2.css>; rel=meta"}}`,
expectLinks: []string{"</test1.css>; rel=preload", "</test2.css>; rel=meta"},
},
} {
c := c
t.Run(name, func(t *testing.T) {
ctx := getContextOrFail(t)
tmpl, err := template.New("").Parse(c.input)
if err != nil {
t.Fatal(err)
}
err = tmpl.Execute(ioutil.Discard, ctx)
if err != nil {
t.Fatal(err)
}
if got := ctx.responseHeader["Link"]; !reflect.DeepEqual(got, c.expectLinks) {
t.Errorf("Result not match: expect %v, but got %v", c.expectLinks, got)
}
})
}
}

View File

@@ -0,0 +1,175 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"net"
"strings"
)
// vhostTrie facilitates virtual hosting. It matches
// requests first by hostname (with support for
// wildcards as TLS certificates support them), then
// by longest matching path.
type vhostTrie struct {
fallbackHosts []string
edges map[string]*vhostTrie
site *SiteConfig // site to match on this node; also known as a virtual host
path string // the path portion of the key for the associated site
}
// newVHostTrie returns a new vhostTrie.
func newVHostTrie() *vhostTrie {
return &vhostTrie{edges: make(map[string]*vhostTrie), fallbackHosts: []string{"0.0.0.0", ""}}
}
// Insert adds stack to t keyed by key. The key should be
// a valid "host/path" combination (or just host).
func (t *vhostTrie) Insert(key string, site *SiteConfig) {
host, path := t.splitHostPath(key)
if _, ok := t.edges[host]; !ok {
t.edges[host] = newVHostTrie()
}
t.edges[host].insertPath(path, path, site)
}
// insertPath expects t to be a host node (not a root node),
// and inserts site into the t according to remainingPath.
func (t *vhostTrie) insertPath(remainingPath, originalPath string, site *SiteConfig) {
if remainingPath == "" {
t.site = site
t.path = originalPath
return
}
ch := string(remainingPath[0])
if _, ok := t.edges[ch]; !ok {
t.edges[ch] = newVHostTrie()
}
t.edges[ch].insertPath(remainingPath[1:], originalPath, site)
}
// Match returns the virtual host (site) in v with
// the closest match to key. If there was a match,
// it returns the SiteConfig and the path portion of
// the key used to make the match. The matched path
// would be a prefix of the path portion of the
// key, if not the whole path portion of the key.
// If there is no match, nil and empty string will
// be returned.
//
// A typical key will be in the form "host" or "host/path".
func (t *vhostTrie) Match(key string) (*SiteConfig, string) {
host, path := t.splitHostPath(key)
// try the given host, then, if no match, try fallback hosts
branch := t.matchHost(host)
for _, h := range t.fallbackHosts {
if branch != nil {
break
}
branch = t.matchHost(h)
}
if branch == nil {
return nil, ""
}
node := branch.matchPath(path)
if node == nil {
return nil, ""
}
return node.site, node.path
}
// matchHost returns the vhostTrie matching host. The matching
// algorithm is the same as used to match certificates to host
// with SNI during TLS handshakes. In other words, it supports,
// to some degree, the use of wildcard (*) characters.
func (t *vhostTrie) matchHost(host string) *vhostTrie {
// try exact match
if subtree, ok := t.edges[host]; ok {
return subtree
}
// then try replacing labels in the host
// with wildcards until we get a match
labels := strings.Split(host, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if subtree, ok := t.edges[candidate]; ok {
return subtree
}
}
return nil
}
// matchPath traverses t until it finds the longest key matching
// remainingPath, and returns its node.
func (t *vhostTrie) matchPath(remainingPath string) *vhostTrie {
var longestMatch *vhostTrie
for len(remainingPath) > 0 {
ch := string(remainingPath[0])
next, ok := t.edges[ch]
if !ok {
break
}
if next.site != nil {
longestMatch = next
}
t = next
remainingPath = remainingPath[1:]
}
return longestMatch
}
// splitHostPath separates host from path in key.
func (t *vhostTrie) splitHostPath(key string) (host, path string) {
parts := strings.SplitN(key, "/", 2)
host, path = strings.ToLower(parts[0]), "/"
if len(parts) > 1 {
path += parts[1]
}
// strip out the port (if present) from the host, since
// each port has its own socket, and each socket has its
// own listener, and each listener has its own server
// instance, and each server instance has its own vhosts.
// removing the port is a simple way to standardize so
// when requests come in, we can be sure to get a match.
hostname, _, err := net.SplitHostPort(host)
if err == nil {
host = hostname
}
return
}
// String returns a list of all the entries in t; assumes that
// t is a root node.
func (t *vhostTrie) String() string {
var s string
for host, edge := range t.edges {
s += edge.str(host)
}
return s
}
func (t *vhostTrie) str(prefix string) string {
var s string
for key, edge := range t.edges {
if edge.site != nil {
s += prefix + key + "\n"
}
s += edge.str(prefix + key)
}
return s
}

View File

@@ -0,0 +1,155 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httpserver
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestVHostTrie(t *testing.T) {
trie := newVHostTrie()
populateTestTrie(trie, []string{
"example",
"example.com",
"*.example.com",
"example.com/foo",
"example.com/foo/bar",
"*.example.com/test",
})
assertTestTrie(t, trie, []vhostTrieTest{
{"not-in-trie.com", false, "", "/"},
{"example", true, "example", "/"},
{"example.com", true, "example.com", "/"},
{"example.com/test", true, "example.com", "/"},
{"example.com/foo", true, "example.com/foo", "/foo"},
{"example.com/foo/", true, "example.com/foo", "/foo"},
{"EXAMPLE.COM/foo", true, "example.com/foo", "/foo"},
{"EXAMPLE.COM/Foo", true, "example.com", "/"},
{"example.com/foo/bar", true, "example.com/foo/bar", "/foo/bar"},
{"example.com/foo/bar/baz", true, "example.com/foo/bar", "/foo/bar"},
{"example.com/foo/other", true, "example.com/foo", "/foo"},
{"foo.example.com", true, "*.example.com", "/"},
{"foo.example.com/else", true, "*.example.com", "/"},
}, false)
}
func TestVHostTrieWildcard1(t *testing.T) {
trie := newVHostTrie()
populateTestTrie(trie, []string{
"example.com",
"",
})
assertTestTrie(t, trie, []vhostTrieTest{
{"not-in-trie.com", true, "", "/"},
{"example.com", true, "example.com", "/"},
{"example.com/foo", true, "example.com", "/"},
{"not-in-trie.com/asdf", true, "", "/"},
}, true)
}
func TestVHostTrieWildcard2(t *testing.T) {
trie := newVHostTrie()
populateTestTrie(trie, []string{
"0.0.0.0/asdf",
})
assertTestTrie(t, trie, []vhostTrieTest{
{"example.com/asdf/foo", true, "0.0.0.0/asdf", "/asdf"},
{"example.com/foo", false, "", "/"},
{"host/asdf", true, "0.0.0.0/asdf", "/asdf"},
}, true)
}
func TestVHostTrieWildcard3(t *testing.T) {
trie := newVHostTrie()
populateTestTrie(trie, []string{
"*/foo",
})
assertTestTrie(t, trie, []vhostTrieTest{
{"example.com/foo", true, "*/foo", "/foo"},
{"example.com", false, "", "/"},
}, true)
}
func TestVHostTriePort(t *testing.T) {
// Make sure port is stripped out
trie := newVHostTrie()
populateTestTrie(trie, []string{
"example.com:1234",
})
assertTestTrie(t, trie, []vhostTrieTest{
{"example.com/foo", true, "example.com:1234", "/"},
}, true)
}
func populateTestTrie(trie *vhostTrie, keys []string) {
for _, key := range keys {
// we wrap this in a func, passing in the key, otherwise the
// handler always writes the last key to the response, even
// if the handler is actually from one of the earlier keys.
func(key string) {
site := &SiteConfig{
middlewareChain: HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
w.Write([]byte(key))
return 0, nil
}),
}
trie.Insert(key, site)
}(key)
}
}
type vhostTrieTest struct {
query string
expectMatch bool
expectedKey string
matchedPrefix string // the path portion of a key that is expected to be matched
}
func assertTestTrie(t *testing.T, trie *vhostTrie, tests []vhostTrieTest, hasWildcardHosts bool) {
for i, test := range tests {
site, pathPrefix := trie.Match(test.query)
if !test.expectMatch {
if site != nil {
// If not expecting a value, then just make sure we didn't get one
t.Errorf("Test %d: Expected no matches, but got %v", i, site)
}
continue
}
// Otherwise, we must assert we got a value
if site == nil {
t.Errorf("Test %d: Expected non-nil return value, but got: %v", i, site)
continue
}
// And it must be the correct value
resp := httptest.NewRecorder()
site.middlewareChain.ServeHTTP(resp, nil)
actualHandlerKey := resp.Body.String()
if actualHandlerKey != test.expectedKey {
t.Errorf("Test %d: Expected match '%s' but matched '%s'",
i, test.expectedKey, actualHandlerKey)
}
// The path prefix must also be correct
if test.matchedPrefix != pathPrefix {
t.Errorf("Test %d: Expected matched path prefix to be '%s', got '%s'",
i, test.matchedPrefix, pathPrefix)
}
}
}