Skip to content

Commit

Permalink
Merge pull request #1 from ti-mo/attribute-codec
Browse files Browse the repository at this point in the history
Switch to AttributeEncoder/Decoder
  • Loading branch information
ti-mo authored Dec 19, 2019
2 parents f20d5b6 + 5ebce80 commit 716c04d
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 123 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ bench:
bench-integration:
go test -bench=. -tags=integration -exec sudo ./...

cover: cover.out
cover.out: $(SOURCES)
.PHONY: cover
cover:
go test -coverprofile=cover.out -covermode=atomic ./...
# Remove coverage output from files generated by Stringer.
sed -i '/_string.go/d' cover.out
Expand Down
187 changes: 119 additions & 68 deletions attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,33 @@ import (
"fmt"

"github.com/mdlayher/netlink"
"github.com/pkg/errors"
"golang.org/x/sys/unix"
)

// NewAttributeDecoder instantiates a new netlink.AttributeDecoder
// configured with a Big Endian byte order.
func NewAttributeDecoder(b []byte) (*netlink.AttributeDecoder, error) {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
return nil, err
}

// All Netfilter attribute payloads are big-endian. (network byte order)
ad.ByteOrder = binary.BigEndian

return ad, nil
}

// NewAttributeDecoder instantiates a new netlink.AttributeEncoder
// configured with a Big Endian byte order.
func NewAttributeEncoder() *netlink.AttributeEncoder {
ae := netlink.NewAttributeEncoder()

// All Netfilter attribute payloads are big-endian. (network byte order)
ae.ByteOrder = binary.BigEndian

return ae
}

// An Attribute is a copy of a netlink.Attribute that can be nested.
type Attribute struct {

Expand Down Expand Up @@ -138,99 +161,127 @@ func Uint64Bytes(u uint64) []byte {
return d
}

// unmarshalAttributes returns an array of netfilter.Attributes decoded from
// a byte array. This byte array should be taken from the netlink.Message's
// Data payload after the nfHeaderLen offset.
func unmarshalAttributes(b []byte) ([]Attribute, error) {

// Obtain a list of parsed netlink attributes possibly holding
// nested Netfilter attributes in their binary Data field.
attrs, err := netlink.UnmarshalAttributes(b)
if err != nil {
return nil, errors.Wrap(err, errWrapNetlinkUnmarshalAttrs)
}
// decode fills the Attribute's Children field with Attributes
// obtained by exhausting ad.
func (a *Attribute) decode(ad *netlink.AttributeDecoder) error {

var ra []Attribute

// Only allocate backing array when there are netlink attributes to decode.
if len(attrs) != 0 {
ra = make([]Attribute, 0, len(attrs))
}

// Wrap all netlink.Attributes into netfilter.Attributes to support nesting
for _, nla := range attrs {
for ad.Next() {

// Copy the netlink attribute's fields into the netfilter attribute.
nfa := Attribute{
// Only consider the rightmost 14 bits for Type
Type: nla.Type & ^(uint16(unix.NLA_F_NESTED) | uint16(unix.NLA_F_NET_BYTEORDER)),
Data: nla.Data,
// Only consider the rightmost 14 bits for Type.
// ad.Type implicitly masks the Nested and NetByteOrder bits.
Type: ad.Type(),
Data: ad.Bytes(),
}

// Boolean flags extracted from the two leftmost bits of Type
nfa.Nested = (nla.Type & uint16(unix.NLA_F_NESTED)) != 0
nfa.NetByteOrder = (nla.Type & uint16(unix.NLA_F_NET_BYTEORDER)) != 0
// Boolean flags extracted from the two leftmost bits of Type.
nfa.Nested = ad.TypeFlags()&netlink.Nested != 0
nfa.NetByteOrder = ad.TypeFlags()&netlink.NetByteOrder != 0

if nfa.NetByteOrder && nfa.Nested {
return nil, errInvalidAttributeFlags
return errInvalidAttributeFlags
}

// Unmarshal recursively if the netlink Nested flag is set
// Unmarshal recursively if the netlink Nested flag is set.
if nfa.Nested {
if nfa.Children, err = unmarshalAttributes(nla.Data); err != nil {
return nil, err
}
ad.Nested(nfa.decode)
}

ra = append(ra, nfa)
a.Children = append(a.Children, nfa)
}

return ra, nil
return ad.Err()
}

// marshalAttributes marshals a nested attribute structure into a byte slice.
// This byte slice can then be copied into a netlink.Message's Data field after
// the nfHeaderLen offset.
func marshalAttributes(attrs []Attribute) ([]byte, error) {
// encode returns a function that takes an AttributeEncoder and returns error.
// This function can be passed to AttributeEncoder.Nested for recursively
// encoding Attributes.
func (a *Attribute) encode(attrs []Attribute) func(*netlink.AttributeEncoder) error {

return func(ae *netlink.AttributeEncoder) error {

// netlink.Attribute to use as scratch buffer, requires a single allocation
nla := netlink.Attribute{}
for _, nfa := range attrs {

// Output array, initialized to the length of the input array
ra := make([]netlink.Attribute, 0, len(attrs))
if nfa.NetByteOrder && nfa.Nested {
return errInvalidAttributeFlags
}

for _, nfa := range attrs {
if nfa.Nested {
ae.Nested(nfa.Type, nfa.encode(nfa.Children))
continue
}

if nfa.NetByteOrder && nfa.Nested {
return nil, errInvalidAttributeFlags
// Manually set the NetByteOrder flag, since ae.Bytes() can't.
if nfa.NetByteOrder {
nfa.Type |= netlink.NetByteOrder
}
ae.Bytes(nfa.Type, nfa.Data)
}

// Save nested or byte order flags back to the netlink.Attribute's
// Type field to include it in the marshaling operation
nla.Type = nfa.Type
return nil
}
}

switch {
case nfa.Nested:
nla.Type = nla.Type | unix.NLA_F_NESTED
case nfa.NetByteOrder:
nla.Type = nla.Type | unix.NLA_F_NET_BYTEORDER
}
// decodeAttributes returns an array of netfilter.Attributes decoded from
// a byte array. This byte array should be taken from the netlink.Message's
// Data payload after the nfHeaderLen offset.
func decodeAttributes(ad *netlink.AttributeDecoder) ([]Attribute, error) {

// Recursively marshal the attribute's children
if nfa.Nested {
nfnab, err := marshalAttributes(nfa.Children)
if err != nil {
return nil, err
}
// Use the Children element of the Attribute to decode into.
// Attribute already has nested decoding implemented on the type.
var a Attribute

nla.Data = nfnab
} else {
nla.Data = nfa.Data
}
// Pre-allocate backing array when there are netlink attributes to decode.
if ad.Len() != 0 {
a.Children = make([]Attribute, 0, ad.Len())
}

// Catch any errors encountered parsing netfilter structures.
if err := a.decode(ad); err != nil {
return nil, err
}

return a.Children, nil
}

// encodeAttributes encodes a list of Attributes into the given netlink.AttributeEncoder.
func encodeAttributes(ae *netlink.AttributeEncoder, attrs []Attribute) error {

if ae == nil {
return errNilAttributeEncoder
}

attr := Attribute{}
return attr.encode(attrs)(ae)
}

// MarshalAttributes marshals a nested attribute structure into a byte slice.
// This byte slice can then be copied into a netlink.Message's Data field after
// the nfHeaderLen offset.
func MarshalAttributes(attrs []Attribute) ([]byte, error) {

ae := NewAttributeEncoder()

if err := encodeAttributes(ae, attrs); err != nil {
return nil, err
}

ra = append(ra, nla)
b, err := ae.Encode()
if err != nil {
return nil, err
}

return b, nil
}

// UnmarshalAttributes unmarshals a byte slice into a list of Attributes.
func UnmarshalAttributes(b []byte) ([]Attribute, error) {

ad, err := NewAttributeDecoder(b)
if err != nil {
return nil, err
}

// Marshal all Netfilter attributes into binary representation of Netlink attributes
return netlink.MarshalAttributes(ra)
return decodeAttributes(ad)
}
55 changes: 30 additions & 25 deletions attribute_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package netfilter

import (
"errors"
"strings"
"testing"

Expand Down Expand Up @@ -141,7 +142,7 @@ func TestAttributeMarshalAttributes(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

b, err := marshalAttributes(tt.attrs)
b, err := MarshalAttributes(tt.attrs)
if err != nil {
t.Fatalf("unexpected marshal error: %v", err)
}
Expand Down Expand Up @@ -194,7 +195,7 @@ func TestAttributeMarshalErrors(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := marshalAttributes(tt.attrs)
_, err := MarshalAttributes(tt.attrs)
require.Error(t, err, "marshal must error")

if tt.err != nil {
Expand All @@ -209,17 +210,20 @@ func TestAttributeMarshalErrors(t *testing.T) {
}
}

func TestAttributeUnmarshalErrors(t *testing.T) {
func TestAttributeDecoderErrors(t *testing.T) {
tests := []struct {
name string
b []byte
err error
errWrap string
}{
{
name: "netlink unmarshal error",
b: []byte{1},
errWrap: errWrapNetlinkUnmarshalAttrs,
name: "invalid attribute flags on top-level attribute",
b: []byte{
8, 0, 0, 192, // 192 = nested + netByteOrder
0, 0, 0, 0,
},
err: errInvalidAttributeFlags,
},
{
name: "invalid attribute flags on nested attribute",
Expand All @@ -230,27 +234,19 @@ func TestAttributeUnmarshalErrors(t *testing.T) {
},
err: errInvalidAttributeFlags,
},
{
name: "decoding invalid attribute",
b: []byte{4, 0, 0},
err: errors.New("invalid attribute; length too short or too large"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := unmarshalAttributes(tt.b)

if err == nil {
t.Fatal("unmarshal did not error")
}

if tt.err != nil {
if want, got := tt.err, err; want != got {
t.Fatalf("unexpected error:\n- want: %v\n- got: %v",
want, got.Error())
}
} else if tt.errWrap != "" {
if !strings.HasPrefix(err.Error(), tt.errWrap+":") {
t.Fatalf("unexpected wrapped error:\n- expected prefix: %v\n- error string: %v",
tt.errWrap, err)
}
}
_, err := UnmarshalAttributes(tt.b)
require.Error(t, err)
require.Error(t, tt.err)
require.EqualError(t, err, tt.err.Error())
})
}
}
Expand Down Expand Up @@ -389,19 +385,28 @@ func TestAttributeMarshalTwoWay(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

ad, err := NewAttributeDecoder(tt.b)
if err != nil {
t.Fatal("unexpected error creating AttributeDecoder:", err)
}

// Unmarshal binary content into nested structures
attrs, err := unmarshalAttributes(tt.b)
attrs, err := decodeAttributes(ad)
require.NoError(t, err)

assert.Empty(t, cmp.Diff(tt.attrs, attrs))

var b []byte

// Attempt re-marshal into binary form
b, err = marshalAttributes(tt.attrs)
b, err = MarshalAttributes(tt.attrs)
require.NoError(t, err)

assert.Empty(t, cmp.Diff(tt.b, b))
})
}
}

func TestErrors(t *testing.T) {
assert.EqualError(t, encodeAttributes(nil, nil), errNilAttributeEncoder.Error())
}
2 changes: 1 addition & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestConnQueryMulticast(t *testing.T) {
func TestConnReceive(t *testing.T) {

// Inject a message directly into the nltest connection
connEcho.conn.Send(nlMsgReqAck)
_, _ = connEcho.conn.Send(nlMsgReqAck)

// Drain the socket
_, err := connEcho.Receive()
Expand Down
6 changes: 2 additions & 4 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ import (
"errors"
)

const (
errWrapNetlinkUnmarshalAttrs = "error unmarshaling netlink attributes"
)

var (
// errInvalidAttributeFlags specifies if an Attribute's flag configuration is invalid.
// From a comment in Linux/include/uapi/linux/netlink.h, Nested and NetByteOrder are mutually exclusive.
Expand All @@ -18,4 +14,6 @@ var (
errConnIsMulticast = errors.New("Conn is attached to one or more multicast groups and can no longer be used for bidirectional traffic")

errNoMulticastGroups = errors.New("need one or more multicast groups to join")

errNilAttributeEncoder = errors.New("given AttributeEncoder is nil")
)
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ go 1.12

require (
github.com/google/go-cmp v0.3.1
github.com/mdlayher/netlink v1.0.0
github.com/mdlayher/netlink v1.0.1-0.20191210152442-a1644773bc99
github.com/pkg/errors v0.8.1
github.com/stretchr/testify v1.3.0
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
golang.org/x/sys v0.0.0-20191210023423-ac6580df4449
)
Loading

0 comments on commit 716c04d

Please sign in to comment.