Skip to content

Commit

Permalink
bridge: simplify ip_mc_check_igmp() and ipv6_mc_check_mld() calls
Browse files Browse the repository at this point in the history
This patch refactors ip_mc_check_igmp(), ipv6_mc_check_mld() and
their callers (more precisely, the Linux bridge) to not rely on
the skb_trimmed parameter anymore.

An skb with its tail trimmed to the IP packet length was initially
introduced for the following three reasons:

1) To be able to verify the ICMPv6 checksum.
2) To be able to distinguish the version of an IGMP or MLD query.
   They are distinguishable only by their size.
3) To avoid parsing data for an IGMPv3 or MLDv2 report that is
   beyond the IP packet but still within the skb.

The first case still uses a cloned and potentially trimmed skb to
verfiy. However, there is no need to propagate it to the caller.
For the second and third case explicit IP packet length checks were
added.

This hopefully makes ip_mc_check_igmp() and ipv6_mc_check_mld() easier
to read and verfiy, as well as easier to use.

Signed-off-by: Linus Lüssing <linus.luessing@c0d3.blue>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
T-X authored and davem330 committed Jan 23, 2019
1 parent 6679cf0 commit ba5ea61
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 72 deletions.
11 changes: 10 additions & 1 deletion include/linux/igmp.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <linux/skbuff.h>
#include <linux/timer.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/refcount.h>
#include <uapi/linux/igmp.h>

Expand Down Expand Up @@ -106,6 +107,14 @@ struct ip_mc_list {
#define IGMPV3_QQIC(value) IGMPV3_EXP(0x80, 4, 3, value)
#define IGMPV3_MRC(value) IGMPV3_EXP(0x80, 4, 3, value)

static inline int ip_mc_may_pull(struct sk_buff *skb, unsigned int len)
{
if (skb_transport_offset(skb) + ip_transport_len(skb) < len)
return -EINVAL;

return pskb_may_pull(skb, len);
}

extern int ip_check_mc_rcu(struct in_device *dev, __be32 mc_addr, __be32 src_addr, u8 proto);
extern int igmp_rcv(struct sk_buff *);
extern int ip_mc_join_group(struct sock *sk, struct ip_mreqn *imr);
Expand All @@ -130,6 +139,6 @@ extern void ip_mc_unmap(struct in_device *);
extern void ip_mc_remap(struct in_device *);
extern void ip_mc_dec_group(struct in_device *in_dev, __be32 addr);
extern void ip_mc_inc_group(struct in_device *in_dev, __be32 addr);
int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed);
int ip_mc_check_igmp(struct sk_buff *skb);

#endif
5 changes: 5 additions & 0 deletions include/linux/ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,9 @@ static inline struct iphdr *ipip_hdr(const struct sk_buff *skb)
{
return (struct iphdr *)skb_transport_header(skb);
}

static inline unsigned int ip_transport_len(const struct sk_buff *skb)
{
return ntohs(ip_hdr(skb)->tot_len) - skb_network_header_len(skb);
}
#endif /* _LINUX_IP_H */
6 changes: 6 additions & 0 deletions include/linux/ipv6.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ static inline struct ipv6hdr *ipipv6_hdr(const struct sk_buff *skb)
return (struct ipv6hdr *)skb_transport_header(skb);
}

static inline unsigned int ipv6_transport_len(const struct sk_buff *skb)
{
return ntohs(ipv6_hdr(skb)->payload_len) + sizeof(struct ipv6hdr) -
skb_network_header_len(skb);
}

/*
This structure contains results of exthdrs parsing
as offsets from skb->nh.
Expand Down
12 changes: 11 additions & 1 deletion include/net/addrconf.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct prefix_info {
struct in6_addr prefix;
};

#include <linux/ipv6.h>
#include <linux/netdevice.h>
#include <net/if_inet6.h>
#include <net/ipv6.h>
Expand Down Expand Up @@ -201,6 +202,15 @@ u32 ipv6_addr_label(struct net *net, const struct in6_addr *addr,
/*
* multicast prototypes (mcast.c)
*/
static inline int ipv6_mc_may_pull(struct sk_buff *skb,
unsigned int len)
{
if (skb_transport_offset(skb) + ipv6_transport_len(skb) < len)
return -EINVAL;

return pskb_may_pull(skb, len);
}

int ipv6_sock_mc_join(struct sock *sk, int ifindex,
const struct in6_addr *addr);
int ipv6_sock_mc_drop(struct sock *sk, int ifindex,
Expand All @@ -219,7 +229,7 @@ void ipv6_mc_unmap(struct inet6_dev *idev);
void ipv6_mc_remap(struct inet6_dev *idev);
void ipv6_mc_init_dev(struct inet6_dev *idev);
void ipv6_mc_destroy_dev(struct inet6_dev *idev);
int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed);
int ipv6_mc_check_mld(struct sk_buff *skb);
void addrconf_dad_failure(struct sk_buff *skb, struct inet6_ifaddr *ifp);

bool ipv6_chk_mcast_addr(struct net_device *dev, const struct in6_addr *group,
Expand Down
4 changes: 2 additions & 2 deletions net/batman-adv/multicast.c
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ static void batadv_mcast_mla_update(struct work_struct *work)
*/
static bool batadv_mcast_is_report_ipv4(struct sk_buff *skb)
{
if (ip_mc_check_igmp(skb, NULL) < 0)
if (ip_mc_check_igmp(skb) < 0)
return false;

switch (igmp_hdr(skb)->type) {
Expand Down Expand Up @@ -741,7 +741,7 @@ static int batadv_mcast_forw_mode_check_ipv4(struct batadv_priv *bat_priv,
*/
static bool batadv_mcast_is_report_ipv6(struct sk_buff *skb)
{
if (ipv6_mc_check_mld(skb, NULL) < 0)
if (ipv6_mc_check_mld(skb) < 0)
return false;

switch (icmp6_hdr(skb)->icmp6_type) {
Expand Down
57 changes: 28 additions & 29 deletions net/bridge/br_multicast.c
Original file line number Diff line number Diff line change
Expand Up @@ -938,15 +938,15 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge *br,

for (i = 0; i < num; i++) {
len += sizeof(*grec);
if (!pskb_may_pull(skb, len))
if (!ip_mc_may_pull(skb, len))
return -EINVAL;

grec = (void *)(skb->data + len - sizeof(*grec));
group = grec->grec_mca;
type = grec->grec_type;

len += ntohs(grec->grec_nsrcs) * 4;
if (!pskb_may_pull(skb, len))
if (!ip_mc_may_pull(skb, len))
return -EINVAL;

/* We treat this as an IGMPv2 report for now. */
Expand Down Expand Up @@ -985,15 +985,17 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
unsigned int nsrcs_offset;
const unsigned char *src;
struct icmp6hdr *icmp6h;
struct mld2_grec *grec;
unsigned int grec_len;
int i;
int len;
int num;
int err = 0;

if (!pskb_may_pull(skb, sizeof(*icmp6h)))
if (!ipv6_mc_may_pull(skb, sizeof(*icmp6h)))
return -EINVAL;

icmp6h = icmp6_hdr(skb);
Expand All @@ -1003,21 +1005,25 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br,
for (i = 0; i < num; i++) {
__be16 *nsrcs, _nsrcs;

nsrcs = skb_header_pointer(skb,
len + offsetof(struct mld2_grec,
grec_nsrcs),
nsrcs_offset = len + offsetof(struct mld2_grec, grec_nsrcs);

if (skb_transport_offset(skb) + ipv6_transport_len(skb) <
nsrcs_offset + sizeof(_nsrcs))
return -EINVAL;

nsrcs = skb_header_pointer(skb, nsrcs_offset,
sizeof(_nsrcs), &_nsrcs);
if (!nsrcs)
return -EINVAL;

if (!pskb_may_pull(skb,
len + sizeof(*grec) +
sizeof(struct in6_addr) * ntohs(*nsrcs)))
grec_len = sizeof(*grec) +
sizeof(struct in6_addr) * ntohs(*nsrcs);

if (!ipv6_mc_may_pull(skb, len + grec_len))
return -EINVAL;

grec = (struct mld2_grec *)(skb->data + len);
len += sizeof(*grec) +
sizeof(struct in6_addr) * ntohs(*nsrcs);
len += grec_len;

/* We treat these as MLDv1 reports for now. */
switch (grec->grec_type) {
Expand Down Expand Up @@ -1219,6 +1225,7 @@ static void br_ip4_multicast_query(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
unsigned int transport_len = ip_transport_len(skb);
const struct iphdr *iph = ip_hdr(skb);
struct igmphdr *ih = igmp_hdr(skb);
struct net_bridge_mdb_entry *mp;
Expand All @@ -1228,7 +1235,6 @@ static void br_ip4_multicast_query(struct net_bridge *br,
struct br_ip saddr;
unsigned long max_delay;
unsigned long now = jiffies;
unsigned int offset = skb_transport_offset(skb);
__be32 group;

spin_lock(&br->multicast_lock);
Expand All @@ -1238,14 +1244,14 @@ static void br_ip4_multicast_query(struct net_bridge *br,

group = ih->group;

if (skb->len == offset + sizeof(*ih)) {
if (transport_len == sizeof(*ih)) {
max_delay = ih->code * (HZ / IGMP_TIMER_SCALE);

if (!max_delay) {
max_delay = 10 * HZ;
group = 0;
}
} else if (skb->len >= offset + sizeof(*ih3)) {
} else if (transport_len >= sizeof(*ih3)) {
ih3 = igmpv3_query_hdr(skb);
if (ih3->nsrcs)
goto out;
Expand Down Expand Up @@ -1296,6 +1302,7 @@ static int br_ip6_multicast_query(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
unsigned int transport_len = ipv6_transport_len(skb);
const struct ipv6hdr *ip6h = ipv6_hdr(skb);
struct mld_msg *mld;
struct net_bridge_mdb_entry *mp;
Expand All @@ -1315,7 +1322,7 @@ static int br_ip6_multicast_query(struct net_bridge *br,
(port && port->state == BR_STATE_DISABLED))
goto out;

if (skb->len == offset + sizeof(*mld)) {
if (transport_len == sizeof(*mld)) {
if (!pskb_may_pull(skb, offset + sizeof(*mld))) {
err = -EINVAL;
goto out;
Expand Down Expand Up @@ -1581,12 +1588,11 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
struct sk_buff *skb_trimmed = NULL;
const unsigned char *src;
struct igmphdr *ih;
int err;

err = ip_mc_check_igmp(skb, &skb_trimmed);
err = ip_mc_check_igmp(skb);

if (err == -ENOMSG) {
if (!ipv4_is_local_multicast(ip_hdr(skb)->daddr)) {
Expand All @@ -1612,19 +1618,16 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
err = br_ip4_multicast_add_group(br, port, ih->group, vid, src);
break;
case IGMPV3_HOST_MEMBERSHIP_REPORT:
err = br_ip4_multicast_igmp3_report(br, port, skb_trimmed, vid);
err = br_ip4_multicast_igmp3_report(br, port, skb, vid);
break;
case IGMP_HOST_MEMBERSHIP_QUERY:
br_ip4_multicast_query(br, port, skb_trimmed, vid);
br_ip4_multicast_query(br, port, skb, vid);
break;
case IGMP_HOST_LEAVE_MESSAGE:
br_ip4_multicast_leave_group(br, port, ih->group, vid, src);
break;
}

if (skb_trimmed && skb_trimmed != skb)
kfree_skb(skb_trimmed);

br_multicast_count(br, port, skb, BR_INPUT_SKB_CB(skb)->igmp,
BR_MCAST_DIR_RX);

Expand All @@ -1637,12 +1640,11 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
struct sk_buff *skb_trimmed = NULL;
const unsigned char *src;
struct mld_msg *mld;
int err;

err = ipv6_mc_check_mld(skb, &skb_trimmed);
err = ipv6_mc_check_mld(skb);

if (err == -ENOMSG) {
if (!ipv6_addr_is_ll_all_nodes(&ipv6_hdr(skb)->daddr))
Expand All @@ -1664,20 +1666,17 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
src);
break;
case ICMPV6_MLD2_REPORT:
err = br_ip6_multicast_mld2_report(br, port, skb_trimmed, vid);
err = br_ip6_multicast_mld2_report(br, port, skb, vid);
break;
case ICMPV6_MGM_QUERY:
err = br_ip6_multicast_query(br, port, skb_trimmed, vid);
err = br_ip6_multicast_query(br, port, skb, vid);
break;
case ICMPV6_MGM_REDUCTION:
src = eth_hdr(skb)->h_source;
br_ip6_multicast_leave_group(br, port, &mld->mld_mca, vid, src);
break;
}

if (skb_trimmed && skb_trimmed != skb)
kfree_skb(skb_trimmed);

br_multicast_count(br, port, skb, BR_INPUT_SKB_CB(skb)->igmp,
BR_MCAST_DIR_RX);

Expand Down
23 changes: 4 additions & 19 deletions net/ipv4/igmp.c
Original file line number Diff line number Diff line change
Expand Up @@ -1544,7 +1544,7 @@ static inline __sum16 ip_mc_validate_checksum(struct sk_buff *skb)
return skb_checksum_simple_validate(skb);
}

static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
static int __ip_mc_check_igmp(struct sk_buff *skb)

{
struct sk_buff *skb_chk;
Expand All @@ -1566,16 +1566,10 @@ static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
if (ret)
goto err;

if (skb_trimmed)
*skb_trimmed = skb_chk;
/* free now unneeded clone */
else if (skb_chk != skb)
kfree_skb(skb_chk);

ret = 0;

err:
if (ret && skb_chk && skb_chk != skb)
if (skb_chk && skb_chk != skb)
kfree_skb(skb_chk);

return ret;
Expand All @@ -1584,7 +1578,6 @@ static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
/**
* ip_mc_check_igmp - checks whether this is a sane IGMP packet
* @skb: the skb to validate
* @skb_trimmed: to store an skb pointer trimmed to IPv4 packet tail (optional)
*
* Checks whether an IPv4 packet is a valid IGMP packet. If so sets
* skb transport header accordingly and returns zero.
Expand All @@ -1594,18 +1587,10 @@ static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
* -ENOMSG: IP header validation succeeded but it is not an IGMP packet.
* -ENOMEM: A memory allocation failure happened.
*
* Optionally, an skb pointer might be provided via skb_trimmed (or set it
* to NULL): After parsing an IGMP packet successfully it will point to
* an skb which has its tail aligned to the IP packet end. This might
* either be the originally provided skb or a trimmed, cloned version if
* the skb frame had data beyond the IP packet. A cloned skb allows us
* to leave the original skb and its full frame unchanged (which might be
* desirable for layer 2 frame jugglers).
*
* Caller needs to set the skb network header and free any returned skb if it
* differs from the provided skb.
*/
int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
int ip_mc_check_igmp(struct sk_buff *skb)
{
int ret = ip_mc_check_iphdr(skb);

Expand All @@ -1615,7 +1600,7 @@ int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
if (ip_hdr(skb)->protocol != IPPROTO_IGMP)
return -ENOMSG;

return __ip_mc_check_igmp(skb, skb_trimmed);
return __ip_mc_check_igmp(skb);
}
EXPORT_SYMBOL(ip_mc_check_igmp);

Expand Down
Loading

0 comments on commit ba5ea61

Please sign in to comment.