From c51b417e0a53162d79e60dd017c849787d7a4a99 Mon Sep 17 00:00:00 2001 From: Paolo Invernizzi Date: Mon, 1 Jul 2024 23:18:09 +0200 Subject: [PATCH] Add UDP unit tests --- protocols/udp.go | 5 +- protocols/udp_test.go | 151 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 protocols/udp_test.go diff --git a/protocols/udp.go b/protocols/udp.go index 1c252bb..9072f04 100644 --- a/protocols/udp.go +++ b/protocols/udp.go @@ -24,11 +24,14 @@ var errInvalidUDPHeader = errors.New("UDP header must be 8 bytes") // An error is returned if the headers' constraints are not respected. func UDPPacketFromIPPacket(ip IPPacket) (*UDPPacket, error) { udpHeader, err := udpHeaderFromBytes(ip.Payload()) + if err != nil { + return nil, err + } return &UDPPacket{ ipPacket: ip, header: *udpHeader, - }, err + }, nil } func udpv4PacketFromBytes(raw []byte) (*UDPPacket, error) { diff --git a/protocols/udp_test.go b/protocols/udp_test.go new file mode 100644 index 0000000..08aa3fa --- /dev/null +++ b/protocols/udp_test.go @@ -0,0 +1,151 @@ +package protocols + +import ( + "reflect" + "testing" +) + +func TestUDPPacketFromIPPacket(t *testing.T) { + tests := []struct { + name string + ipPacket IPPacket + expectedUDPPacket *UDPPacket + expectedErr error + }{ + { + name: "valid IPv4 packet with UDP payload", + ipPacket: ipv4Packet{ + payload: []byte{0x1f, 0x90, 0x23, 0xc4, 0x00, 0x10, 0x27, 0x10}, + }, + expectedUDPPacket: &UDPPacket{ + ipPacket: ipv4Packet{ + payload: []byte{0x1f, 0x90, 0x23, 0xc4, 0x00, 0x10, 0x27, 0x10}, + }, + header: udpHeader{ + sourcePort: 8080, + destinationPort: 9156, + length: 16, + checksum: 10000, + }, + }, + expectedErr: nil, + }, + { + name: "IPv4 packet with too short UDP payload", + ipPacket: ipv4Packet{ + payload: []byte{0x1f, 0x90, 0x23}, + }, + expectedUDPPacket: nil, + expectedErr: errInvalidUDPHeader, + }, + { + name: "Valid IPv6 packet with UDP payload", + ipPacket: ipv6Packet{ + payload: []byte{0x00, 0x01, 0x00, 0x02, 0x00, 0x08, 0x00, 0x00}, + }, + expectedUDPPacket: &UDPPacket{ + ipPacket: ipv6Packet{ + payload: []byte{0x00, 0x01, 0x00, 0x02, 0x00, 0x08, 0x00, 0x00}, + }, + header: udpHeader{ + sourcePort: 1, + destinationPort: 2, + length: 8, + checksum: 0, + }, + }, + expectedErr: nil, + }, + { + name: "IPv6 packet with zero length UDP payload", + ipPacket: ipv6Packet{ + payload: []byte{0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x9a, 0xbc}, + }, + expectedUDPPacket: &UDPPacket{ + ipPacket: ipv6Packet{ + payload: []byte{0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x9a, 0xbc}, + }, + header: udpHeader{ + sourcePort: 0x1234, + destinationPort: 0x5678, + length: 0, + checksum: 0x9abc, + }, + }, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + udp, err := UDPPacketFromIPPacket(tt.ipPacket) + if tt.expectedErr != err { + t.Errorf("expected error: %v - got %v", tt.expectedErr, err) + } + if tt.expectedErr == nil && !reflect.DeepEqual(udp, tt.expectedUDPPacket) { + t.Errorf("expected UDP packet to be %+v - got %+v", tt.expectedUDPPacket, udp) + } + }) + } +} + +func TestUDPHeaderFromBytes(t *testing.T) { + tests := []struct { + name string + raw []byte + expectedHeader *udpHeader + expectedErr error + }{ + { + name: "Valid UDP header", + raw: []byte{0x1f, 0x90, 0x23, 0xc4, 0x00, 0x10, 0x27, 0x10}, + expectedHeader: &udpHeader{ + sourcePort: 8080, + destinationPort: 9156, + length: 16, + checksum: 10000, + }, + expectedErr: nil, + }, + { + name: "Too short header", + raw: []byte{0x1f, 0x90, 0x23}, + expectedHeader: nil, + expectedErr: errInvalidUDPHeader, + }, + { + name: "Minimum valid header", + raw: []byte{0x00, 0x01, 0x00, 0x02, 0x00, 0x08, 0x00, 0x00}, + expectedHeader: &udpHeader{ + sourcePort: 1, + destinationPort: 2, + length: 8, + checksum: 0, + }, + expectedErr: nil, + }, + { + name: "Zero length header", + raw: []byte{0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x9a, 0xbc}, + expectedHeader: &udpHeader{ + sourcePort: 0x1234, + destinationPort: 0x5678, + length: 0, + checksum: 0x9abc, + }, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, err := udpHeaderFromBytes(tt.raw) + if tt.expectedErr != err { + t.Errorf("expected error: %v - got %v", tt.expectedErr, err) + } + if tt.expectedErr == nil && !reflect.DeepEqual(h, tt.expectedHeader) { + t.Errorf("expected header to be %+v - got %+v", tt.expectedHeader, h) + } + }) + } +}