cloudflared/tunnelrpc/pogs/auth_test.go
Devin Carr eb2e4349e8 TUN-8415: Refactor capnp rpc into a single module
Combines the tunnelrpc and quic/schema capnp files into the same module.

To help reduce future issues with capnp id generation, capnpids are
provided in the capnp files from the existing capnp struct ids generated
in the go files.

Reduces the overall interface of the Capnp methods to the rest of
the code by providing an interface that will handle the quic protocol
selection.

Introduces a new `rpc-timeout` config that will allow all of the
SessionManager and ConfigurationManager RPC requests to have a timeout.
The timeout for these values is set to 5 seconds as none of these operations
for the managers should take a long time to complete.

Removed the RPC-specific logger as it never provided good debugging value
as the RPC method names were not visible in the logs.
2024-05-17 11:22:07 -07:00

137 lines
3.6 KiB
Go

package pogs
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
capnp "zombiezen.com/go/capnproto2"
"github.com/cloudflare/cloudflared/tunnelrpc/proto"
)
// Compile-time assertions that the pointer type of every outcome variant
// satisfies the AuthOutcome interface.
var (
	_ AuthOutcome = (*AuthSuccess)(nil)
	_ AuthOutcome = (*AuthFail)(nil)
	_ AuthOutcome = (*AuthUnknown)(nil)
)
// TestAuthenticateResponseOutcome is a table-driven test for
// AuthenticateResponse.Outcome().
//
// Each case builds an AuthenticateResponse from the listed fields and checks
// that:
//   - Outcome() returns the expected AuthOutcome (the last two cases expect
//     nil), and
//   - for every non-nil outcome, Serialize() reproduces the original response,
//     i.e. Outcome() and Serialize() are inverses.
func TestAuthenticateResponseOutcome(t *testing.T) {
	// fields mirrors the AuthenticateResponse fields that the cases populate.
	type fields struct {
		PermanentErr      string
		RetryableErr      string
		Jwt               []byte
		HoursUntilRefresh uint8
	}
	tests := []struct {
		name   string
		fields fields
		want   AuthOutcome
	}{
		{"success",
			fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
			AuthSuccess{jwt: []byte("asdf"), hoursUntilRefresh: 6},
		},
		{"fail",
			fields{PermanentErr: "bad creds"},
			AuthFail{err: fmt.Errorf("bad creds")},
		},
		{"error",
			fields{RetryableErr: "bad conn", HoursUntilRefresh: 6},
			AuthUnknown{err: fmt.Errorf("bad conn"), hoursUntilRefresh: 6},
		},
		{"nil (no fields are set)",
			fields{},
			nil,
		},
		{"nil (too few fields are set)",
			fields{HoursUntilRefresh: 6},
			nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ar := AuthenticateResponse{
				PermanentErr:      tt.fields.PermanentErr,
				RetryableErr:      tt.fields.RetryableErr,
				Jwt:               tt.fields.Jwt,
				HoursUntilRefresh: tt.fields.HoursUntilRefresh,
			}
			got := ar.Outcome()
			if !reflect.DeepEqual(got, tt.want) {
				// Print both sides with %#v so a failure shows the concrete
				// type and the field values of what was actually returned,
				// not just its type name.
				t.Errorf("AuthenticateResponse.Outcome() = %#v, want %#v", got, tt.want)
			}
			// The round-trip check only applies when an outcome was produced.
			if got != nil && !reflect.DeepEqual(got.Serialize(), ar) {
				t.Errorf(".Outcome() and .Serialize() should be inverses but weren't. Expected %v, got %v", ar, got.Serialize())
			}
		})
	}
}
// TestAuthSuccess checks that an AuthSuccess comes back unchanged after being
// serialized and converted back into an outcome.
func TestAuthSuccess(t *testing.T) {
	want := NewAuthSuccess([]byte("asdf"), 6)

	outcome := want.Serialize().Outcome()
	got, isSuccess := outcome.(AuthSuccess)

	assert.True(t, isSuccess)
	assert.Equal(t, want, got)
}
// TestAuthUnknown checks that an AuthUnknown comes back unchanged after being
// serialized and converted back into an outcome.
func TestAuthUnknown(t *testing.T) {
	want := NewAuthUnknown(fmt.Errorf("pdx unreachable"), 6)

	outcome := want.Serialize().Outcome()
	got, isUnknown := outcome.(AuthUnknown)

	assert.True(t, isUnknown)
	assert.Equal(t, want, got)
}
// TestAuthFail checks that an AuthFail comes back unchanged after being
// serialized and converted back into an outcome.
func TestAuthFail(t *testing.T) {
	want := NewAuthFail(fmt.Errorf("wrong creds"))

	outcome := want.Serialize().Outcome()
	got, isFail := outcome.(AuthFail)

	assert.True(t, isFail)
	assert.Equal(t, want, got)
}
// TestWhenToRefresh verifies that hoursToTime(4) equals 4 * time.Hour.
func TestWhenToRefresh(t *testing.T) {
	const want = 4 * time.Hour
	if got := hoursToTime(4); got != want {
		t.Fatalf("expected %v hours, got %v", want, got)
	}
}
// TestSerializeAuthenticationResponse checks that MarshalAuthenticateResponse
// and UnmarshalAuthenticateResponse undo each other.
//
// Each case is marshalled into a freshly allocated capnp AuthenticateResponse,
// unmarshalled again, and compared with the original struct. Setup failures
// abort the test; marshal/unmarshal failures are reported and the loop moves
// on to the next case.
func TestSerializeAuthenticationResponse(t *testing.T) {
	tests := []*AuthenticateResponse{
		{
			Jwt:               []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"),
			HoursUntilRefresh: 24,
		},
		{
			PermanentErr: "bad auth",
		},
		{
			RetryableErr:      "bad connection",
			HoursUntilRefresh: 24,
		},
	}
	for i, testCase := range tests {
		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
		if !assert.NoError(t, err) {
			// Do not hand seg to the constructor below after a failed
			// allocation; stop here instead of continuing with it.
			t.Fatal("Couldn't allocate a new capnp message")
		}
		capnpEntity, err := proto.NewAuthenticateResponse(seg)
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't initialize a new message")
		}
		err = MarshalAuthenticateResponse(capnpEntity, testCase)
		if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
			continue
		}
		result, err := UnmarshalAuthenticateResponse(capnpEntity)
		if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
			continue
		}
		assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
	}
}