diff --git a/quic/v3/muxer.go b/quic/v3/muxer.go index ed688fea..29673a3b 100644 --- a/quic/v3/muxer.go +++ b/quic/v3/muxer.go @@ -65,12 +65,11 @@ type datagramConn struct { icmpRouter ingress.ICMPRouter metrics Metrics logger *zerolog.Logger - - datagrams chan []byte - readErrors chan error + datagrams chan []byte + readErrors chan error icmpEncoderPool sync.Pool // a pool of *packet.Encoder - icmpDecoder *packet.ICMPDecoder + icmpDecoderPool sync.Pool } func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn { @@ -89,7 +88,11 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRou return packet.NewEncoder() }, }, - icmpDecoder: packet.NewICMPDecoder(), + icmpDecoderPool: sync.Pool{ + New: func() any { + return packet.NewICMPDecoder() + }, + }, } } @@ -367,7 +370,16 @@ func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) { // Decode the provided ICMPDatagram as an ICMP packet rawPacket := packet.RawPacket{Data: datagram.Payload} - icmp, err := c.icmpDecoder.Decode(rawPacket) + cachedDecoder := c.icmpDecoderPool.Get() + defer c.icmpDecoderPool.Put(cachedDecoder) + decoder, ok := cachedDecoder.(*packet.ICMPDecoder) + if !ok { + c.logger.Error().Msg("Could not get ICMPDecoder from the pool. Dropping packet") + return + } + + icmp, err := decoder.Decode(rawPacket) + if err != nil { c.logger.Err(err).Msgf("unable to marshal icmp packet") return diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go index 7b532ba3..90c167a6 100644 --- a/quic/v3/muxer_test.go +++ b/quic/v3/muxer_test.go @@ -4,13 +4,17 @@ import ( "bytes" "context" "errors" + "fmt" "net" "net/netip" "slices" + "sort" "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/google/gopacket/layers" "github.com/rs/zerolog" "golang.org/x/net/icmp" @@ -304,6 +308,91 @@ func TestDatagramConnServe(t *testing.T) { assertContextClosed(t, ctx, done, cancel) } +// This test exists because decoding multiple packets in parallel with the same decoder +// instances causes inteference resulting in multiple different raw packets being decoded +// as the same decoded packet. +func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + session := newMockSession() + sessionManager := mockSessionManager{session: &session} + router := newMockICMPRouter() + conn := v3.NewDatagramConn(quic, &sessionManager, router, 0, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + packetCount := 100 + packets := make([]*packet.ICMP, 100) + ipTemplate := "10.0.0.%d" + for i := 1; i <= packetCount; i++ { + packets[i-1] = &packet.ICMP{ + IP: &packet.IP{ + Src: netip.MustParseAddr("192.168.1.1"), + Dst: netip.MustParseAddr(fmt.Sprintf(ipTemplate, i)), + Protocol: layers.IPProtocolICMPv4, + TTL: 20, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 25821, + Seq: 58129, + Data: []byte("test"), + }, + }, + } + } + + wg := sync.WaitGroup{} + var receivedPackets []*packet.ICMP + go func() { + for ctx.Err() == nil { + select { + case icmpPacket := <-router.recv: + receivedPackets = append(receivedPackets, icmpPacket) + wg.Done() + } + } + }() + + for _, p := range packets { + // We increment here but only decrement when receiving the packet + wg.Add(1) + go func() { + datagram := newICMPDatagram(p) + quic.send <- datagram + }() + } + + wg.Wait() + + // If there were duplicates then we won't have the same number of IPs + packetSet := make(map[netip.Addr]*packet.ICMP, 0) + for _, p := range receivedPackets { + packetSet[p.Dst] = p + } + assert.Equal(t, len(packetSet), len(packets)) + + // Sort the slice by last byte of IP address (the one we increment for each destination) + // and then check that we have one match for each packet sent + sort.Slice(receivedPackets, func(i, j int) bool { + return receivedPackets[i].Dst.As4()[3] < receivedPackets[j].Dst.As4()[3] + }) + for i, p := range receivedPackets { + assert.Equal(t, p.Dst, packets[i].Dst) + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + func TestDatagramConnServe_RegisterTwice(t *testing.T) { log := zerolog.Nop() quic := newMockQuicConn()