cloudflared/quic/safe_stream_test.go
Devin Carr eb2e4349e8 TUN-8415: Refactor capnp rpc into a single module
Combines the tunnelrpc and quic/schema capnp files into the same module.

To help reduce future issues with capnp id generation, capnpids are
provided in the capnp files from the existing capnp struct ids generated
in the go files.

Reduces the overall interface of the Capnp methods to the rest of
the code by providing an interface that will handle the quic protocol
selection.

Introduces a new `rpc-timeout` config that will allow all of the
SessionManager and ConfigurationManager RPC requests to have a timeout.
The timeout for these values is set to 5 seconds as none of these operations
for the managers should take a long time to complete.

Removed the RPC-specific logger as it never provided good debugging value
as the RPC method names were not visible in the logs.
2024-05-17 11:22:07 -07:00

179 lines
4.2 KiB
Go

package quic
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"io"
"math/big"
"net"
"sync"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/require"
)
var (
testTLSServerConfig = GenerateTLSConfig()
testQUICConfig = &quic.Config{
KeepAlivePeriod: 5 * time.Second,
EnableDatagrams: true,
}
exchanges = 1000
msgsPerExchange = 10
testMsg = "Ok message"
)
// TestSafeStreamClose starts a QUIC server and a QUIC client on a loopback UDP
// socket, lets them run their stream exchanges concurrently, and returns once
// both sides have finished.
func TestSafeStreamClose(t *testing.T) {
	loopback, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	require.NoError(t, err)
	udpListener, err := net.ListenUDP(loopback.Network(), loopback)
	require.NoError(t, err)
	defer udpListener.Close()

	var (
		listening sync.WaitGroup // released by the server once it is listening
		finished  sync.WaitGroup // tracks the server and client goroutines
	)
	listening.Add(1)
	finished.Add(2)

	go func() {
		defer finished.Done()
		quicServer(t, &listening, udpListener)
	}()

	go func() {
		defer finished.Done()
		// Do not dial before the server has signalled that it is ready.
		listening.Wait()
		quicClient(t, udpListener.LocalAddr())
	}()

	finished.Wait()
}
// quicClient dials the QUIC server at addr and accepts `exchanges` streams from
// it. Each accepted stream is wrapped in a SafeStreamCloser and handled in its
// own goroutine; the function returns after every stream handler has finished.
func quicClient(t *testing.T, addr net.Addr) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	clientTLS := &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"argotunnel"},
	}
	conn, err := quic.DialAddr(ctx, addr.String(), clientTLS, testQUICConfig)
	require.NoError(t, err)

	var pending sync.WaitGroup
	for i := 0; i < exchanges; i++ {
		accepted, err := conn.AcceptStream(context.Background())
		require.NoError(t, err)

		pending.Add(1)
		go func(n int) {
			defer pending.Done()

			nopLog := zerolog.Nop()
			safe := NewSafeStreamCloser(accepted, 30*time.Second, &nopLog)
			defer safe.Close()

			// These reads are all required to succeed.
			for remaining := msgsPerExchange; remaining > 0; remaining-- {
				clientRoundTrip(t, safe, true)
			}

			// On even-numbered streams, attempt one more read whose outcome is
			// ignored; it must not disturb the other streams of the session.
			if n%2 == 0 {
				clientRoundTrip(t, safe, false)
			}
		}(i)
	}

	pending.Wait()
}
// quicServer listens for QUIC on conn, signals serverReady once the listener
// exists, accepts a single connection and opens `exchanges` streams on it.
// Each stream is wrapped in a SafeStreamCloser and driven from its own
// goroutine; the function returns after every stream handler has finished.
func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	listener, err := quic.Listen(conn, testTLSServerConfig, testQUICConfig)
	require.NoError(t, err)
	// Let the client start dialing now that the listener is up.
	serverReady.Done()

	peer, err := listener.Accept(ctx)
	require.NoError(t, err)

	var pending sync.WaitGroup
	for i := 0; i < exchanges; i++ {
		opened, err := peer.OpenStreamSync(context.Background())
		require.NoError(t, err)

		pending.Add(1)
		go func(n int) {
			defer pending.Done()

			nopLog := zerolog.Nop()
			safe := NewSafeStreamCloser(opened, 30*time.Second, &nopLog)
			defer safe.Close()

			// These writes are all required to succeed.
			for remaining := msgsPerExchange; remaining > 0; remaining-- {
				serverRoundTrip(t, safe, true)
			}

			// On odd-numbered streams, attempt one more write whose outcome is
			// ignored; it must not disturb the other streams of the session.
			if n%2 == 1 {
				serverRoundTrip(t, safe, false)
			}
		}(i)
	}

	pending.Wait()
}
// clientRoundTrip reads one testMsg-sized message from stream.
//
// When mustWork is false the read is still attempted, but its result is
// ignored. When mustWork is true the test fails unless a complete message
// equal to testMsg was received.
func clientRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) {
	response := make([]byte, len(testMsg))
	// A single Read may legally return fewer bytes than requested, so keep
	// reading until the whole message has arrived. io.ReadFull reports a nil
	// error whenever the buffer was filled, even if the stream ended right
	// after the message, so no special handling of io.EOF is needed.
	_, err := io.ReadFull(stream, response)
	if !mustWork {
		return
	}
	require.NoError(t, err)
	require.Equal(t, testMsg, string(response))
}
// serverRoundTrip writes testMsg to stream. The write is always attempted; its
// error fails the test only when mustWork is true.
func serverRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) {
	payload := []byte(testMsg)
	if _, err := stream.Write(payload); mustWork {
		require.NoError(t, err)
	}
}
// GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server.
//
// It generates a fresh RSA key, self-signs a certificate whose template only
// carries a serial number, and returns a config holding that certificate and
// advertising the "argotunnel" ALPN protocol. It panics if key generation,
// certificate creation or key-pair loading fails.
func GenerateTLSConfig() *tls.Config {
	// 1024-bit RSA is the bare minimum crypto/rsa still accepts and is
	// considered weak; use 2048 bits instead.
	const rsaKeyBits = 2048

	key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
	if err != nil {
		panic(err)
	}
	// The template is used as both subject and issuer, i.e. the certificate
	// is self-signed.
	template := x509.Certificate{SerialNumber: big.NewInt(1)}
	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
	if err != nil {
		panic(err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		panic(err)
	}
	return &tls.Config{
		Certificates: []tls.Certificate{tlsCert},
		NextProtos:   []string{"argotunnel"},
	}
}