Skip to content

Commit

Permalink
feat: add CPU native compilation and instructions
Browse files Browse the repository at this point in the history
+ add un/likely primitives (when RUST supports we should change)
+ consolidate checksum calculations (now support offload on all layers)
  • Loading branch information
kp-omer-shamash committed Jan 14, 2025
1 parent 2e261c8 commit 29fd97b
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 82 deletions.
3 changes: 3 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"
runner = ["qemu-aarch64-static"] # use qemu user emulation for cargo run and test

[build]
rustflags = ["-C", "target-cpu=native"]
301 changes: 219 additions & 82 deletions lightway-core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,56 @@ use std::net::Ipv4Addr;
use std::ops;
use tracing::warn;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Check if AVX2 is available on the current CPU
#[inline(always)]
fn has_avx2() -> bool {
#[cfg(target_arch = "x86_64")]
{
is_x86_feature_detected!("avx2")
}
#[cfg(not(target_arch = "x86_64"))]
{
false
}
}

/**
* HOT/COLD path implementation until RUST adds
* https://github.com/rust-lang/rust/issues/26179
*/

#[inline]
#[cold]
fn cold() {}

#[inline]
pub(crate) fn likely(b: bool) -> bool {
if !b { cold() }
b
}

#[inline]
pub(crate) fn unlikely(b: bool) -> bool {
if b { cold() }
b
}

/// Validate if a buffer contains a valid IPv4 packet
pub(crate) fn ipv4_is_valid_packet(buf: &[u8]) -> bool {
if buf.is_empty() {
if buf.len() < 20 {
// IPv4 header is at least 20 bytes
return false;
}
let first_byte = buf[0];
let ip_version = first_byte >> 4;

ip_version == 4
}

// Structure to calculate incremental checksum
/// Structure to calculate incremental checksum
#[derive(Clone, Copy)]
struct Checksum(u16);

impl ops::Deref for Checksum {
Expand All @@ -33,122 +72,220 @@ impl ops::Sub<u16> for Checksum {
type Output = Checksum;
fn sub(self, rhs: u16) -> Checksum {
let (n, of) = self.0.overflowing_sub(rhs);
Checksum(match of {
true => n - 1,
false => n,
})
Checksum(if of { n.wrapping_sub(1) } else { n })
}
}

/// Structure to handle checksum updates when modifying IP addresses
struct ChecksumUpdate(Vec<(u16, u16)>);

impl Checksum {
// Based on RFC-1624 [Eqn. 4]
/// Update checksum when replacing one word with another
/// Based on RFC-1624 [Eqn. 4]
fn update_word(self, old_word: u16, new_word: u16) -> Self {
self - !old_word - new_word
}

/// Apply multiple checksum updates
fn update(self, updates: &ChecksumUpdate) -> Self {
updates.0.iter().fold(self, |c, x| c.update_word(x.0, x.1))
updates
.0
.iter()
.fold(self, |c, &(old, new)| c.update_word(old, new))
}
}

struct ChecksumUpdate(Vec<(u16, u16)>);
/// AVX2-accelerated checksum update (unused until we support IPv6)
// #[allow(unsafe_code)]
// #[cfg(target_arch = "x86_64")]
// #[target_feature(enable = "avx2")]
// unsafe fn update_avx2(self, updates: &ChecksumUpdate) -> Self {
// let mut sum = u32::from(self.0);

// // Process 8 words at a time using AVX2
// for chunk in updates.0.chunks(8) {
// // Pre-allocate with known size
// let mut old_words = Vec::with_capacity(8);
// let mut new_words = Vec::with_capacity(8);

// // Fill vectors with data or zeros
// for i in 0..8 {
// if let Some(&(old, new)) = chunk.get(i) {
// old_words.push(i32::from(old));
// new_words.push(i32::from(new));
// } else {
// old_words.push(0);
// new_words.push(0);
// }
// }

// // SAFETY: Vectors are guaranteed to have exactly 8 elements
// unsafe {
// // Load data into AVX2 registers
// let old_vec = _mm256_set_epi32(
// old_words[7],
// old_words[6],
// old_words[5],
// old_words[4],
// old_words[3],
// old_words[2],
// old_words[1],
// old_words[0],
// );
// let new_vec = _mm256_set_epi32(
// new_words[7],
// new_words[6],
// new_words[5],
// new_words[4],
// new_words[3],
// new_words[2],
// new_words[1],
// new_words[0],
// );

// // Compute NOT(old) + new using AVX2
// let not_old = _mm256_xor_si256(old_vec, _mm256_set1_epi32(-1));
// let sum_vec = _mm256_add_epi32(not_old, new_vec);

// // Horizontal sum
// let hadd = _mm256_hadd_epi32(sum_vec, sum_vec);
// let hadd = _mm256_hadd_epi32(hadd, hadd);

// sum = sum.wrapping_add(_mm256_extract_epi32(hadd, 0) as u32);
// }
// }

// // Fold 32-bit sum to 16 bits
// while sum > 0xFFFF {
// sum = (sum & 0xFFFF) + (sum >> 16);
// }

// Checksum(sum as u16)
// }
}

impl ChecksumUpdate {
/// Create checksum update data from IP address change
fn from_ipv4_address(old: Ipv4Addr, new: Ipv4Addr) -> Self {
let mut result = vec![];
let old: [u8; 4] = old.octets();
let new: [u8; 4] = new.octets();
for i in 0..2 {
let old_word = u16::from_be_bytes([old[i * 2], old[i * 2 + 1]]);
let new_word = u16::from_be_bytes([new[i * 2], new[i * 2 + 1]]);
result.push((old_word, new_word));
}
Self(result)
let old_bytes = old.octets();
let new_bytes = new.octets();

// Convert to u16 pairs for checksum calculation
let old_words = [
u16::from_be_bytes([old_bytes[0], old_bytes[1]]),
u16::from_be_bytes([old_bytes[2], old_bytes[3]]),
];
let new_words = [
u16::from_be_bytes([new_bytes[0], new_bytes[1]]),
u16::from_be_bytes([new_bytes[2], new_bytes[3]]),
];

Self(vec![
(old_words[0], new_words[0]),
(old_words[1], new_words[1]),
])
}
}

fn tcp_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: ChecksumUpdate) {
let packet = MutableTcpPacket::new(packet.payload_mut());
let Some(mut packet) = packet else {
warn!("Invalid packet size (less than Tcp header)!");
return;
};

let checksum = Checksum(packet.get_checksum());
let checksum = checksum.update(&updates);
packet.set_checksum(*checksum);
}

fn udp_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: ChecksumUpdate) {
let packet = MutableUdpPacket::new(packet.payload_mut());
let Some(mut packet) = packet else {
warn!("Invalid packet size (less than Udp header)!");
/// Update transport protocol checksums after IP address changes
fn update_transport_checksums(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) {
// Skip if this is not the first fragment
if packet.get_fragment_offset() != 0 {
return;
};

let checksum = Checksum(packet.get_checksum());
}

// UDP checksums are optional, and we should respect that when doing NAT
if *checksum != 0 {
let checksum = checksum.update(&updates);
packet.set_checksum(checksum.0);
match packet.get_next_level_protocol() {
IpNextHeaderProtocols::Tcp => update_tcp_checksum(packet, updates),
IpNextHeaderProtocols::Udp => update_udp_checksum(packet, updates),
IpNextHeaderProtocols::Icmp => {} // ICMP doesn't need checksum update for IP changes
protocol => {
if unlikely(true) {
warn!(protocol = ?protocol, "Unknown protocol, skipping checksum update");
}
},
}
}

fn ipv4_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: ChecksumUpdate) {
let checksum = Checksum(packet.get_checksum());
let checksum = checksum.update(&updates);
packet.set_checksum(*checksum);

// In case of fragmented packets, TCP/UDP header will be present only in the first fragment.
// So skip updating the checksum, if it is not the first fragment (i.e frag_offset != 0)
if 0 != packet.get_fragment_offset() {
return;
fn update_tcp_checksum(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) {
if likely(MutableTcpPacket::new(packet.payload_mut()).is_some()) {
let mut tcp_packet = MutableTcpPacket::new(packet.payload_mut()).unwrap();
let checksum = tcp_packet.get_checksum();
// Only update if checksum is present (not 0)
if checksum != 0 {
let checksum = Checksum(checksum).update(&updates);
tcp_packet.set_checksum(*checksum);
}
} else {
warn!("Invalid packet size (less than TCP header)!");
}
}

let transport_protocol = packet.get_next_level_protocol();
match transport_protocol {
IpNextHeaderProtocols::Tcp => tcp_adjust_packet_checksum(packet, updates),
IpNextHeaderProtocols::Udp => udp_adjust_packet_checksum(packet, updates),
IpNextHeaderProtocols::Icmp => {}
protocol => {
warn!(protocol = ?protocol, "Unknown protocol, skipping checksum adjust")
fn update_udp_checksum(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) {
if likely(MutableUdpPacket::new(packet.payload_mut()).is_some()) {
let mut udp_packet = MutableUdpPacket::new(packet.payload_mut()).unwrap();
let checksum = udp_packet.get_checksum();
// Only update if checksum is present (not 0)
if checksum != 0 {
let checksum = Checksum(checksum).update(&updates);
udp_packet.set_checksum(*checksum);
}
} else {
warn!("Invalid packet size (less than UDP header)!");
}
}

/// Utility function to update source ip address in ipv4 packet buffer
/// Nop if buf is not a valid IPv4 packet
pub fn ipv4_update_source(buf: &mut [u8], ip: Ipv4Addr) {
let packet = MutableIpv4Packet::new(buf);
let Some(mut packet) = packet else {
warn!("Invalid packet size (less than Ipv4 header)!");
#[derive(Clone, Copy)]
enum IpField {
Source,
Destination,
}

/// NOTE: the field is compile-time known, so gets optimized, this is for better maintanance
#[inline(always)]
fn ipv4_update_field(buf: &mut [u8], new_ip: Ipv4Addr, field: IpField) {
let Some(mut packet) = MutableIpv4Packet::new(buf) else {
if unlikely(true) {
warn!("Failed to create IPv4 packet!");
}
return;
};

let old = packet.get_source();
// Set new source only after getting old source ip address
packet.set_source(ip);

ipv4_adjust_packet_checksum(packet, ChecksumUpdate::from_ipv4_address(old, ip));
}
// Get old IP before updating
let old_ip = match field {
IpField::Source => packet.get_source(),
IpField::Destination => packet.get_destination(),
};

/// Utility function to update destination ip address in ipv4 packet buffer
/// Nop if buf is not a valid IPv4 packet
pub fn ipv4_update_destination(buf: &mut [u8], ip: Ipv4Addr) {
let packet = MutableIpv4Packet::new(buf);
let Some(mut packet) = packet else {
warn!("Invalid packet size (less than Ipv4 header)!");
return;
// Update IP field
match field {
IpField::Source => packet.set_source(new_ip),
IpField::Destination => packet.set_destination(new_ip),
};

let old = packet.get_destination();
// Set new destination only after getting old destination ip address
packet.set_destination(ip);
// Update checksums
let updates = ChecksumUpdate::from_ipv4_address(old_ip, new_ip);
let checksum = packet.get_checksum();
if checksum != 0 {
let checksum = Checksum(checksum).update(&updates);
packet.set_checksum(*checksum);
}

// Update transport protocol checksums
update_transport_checksums(&mut packet, updates);
}

/// Update source IP address in an IPv4 packet
#[inline]
pub fn ipv4_update_source(buf: &mut [u8], new_ip: Ipv4Addr) {
ipv4_update_field(buf, new_ip, IpField::Source)
}

ipv4_adjust_packet_checksum(packet, ChecksumUpdate::from_ipv4_address(old, ip));
/// Update destination IP address in an IPv4 packet
#[inline]
pub fn ipv4_update_destination(buf: &mut [u8], new_ip: Ipv4Addr) {
ipv4_update_field(buf, new_ip, IpField::Destination)
}

/// Clamp TCP MSS option if present in a TCP SYN packet
pub fn tcp_clamp_mss(pkt: &mut [u8], mss: u16) -> Option<u16> {
let mut ipv4_packet = MutableIpv4Packet::new(pkt)?;

Expand Down

0 comments on commit 29fd97b

Please sign in to comment.