Skip to content

Commit

Permalink
Enable some numeric cast releated clippy lints and fix them in the co…
Browse files Browse the repository at this point in the history
…de base

* clippy::cast_possible_wrap
* clippy::cast_possible_truncation
* clippy::cast_sign_loss

These lints can point to serious problems if you hit one of the edge
cases in low level unsafe/byte shuffling code

This is a reaction to
https://media.defcon.org/DEF%20CON%2032/DEF%20CON%2032%20presentations/DEF%20CON%2032%20-%20Paul%20Gerste%20-%20SQL%20Injection%20Isn't%20Dead%20Smuggling%20Queries%20at%20the%20Protocol%20Level.pdf

It fixes several places that could be possibly exploited by specially
crafted values.
  • Loading branch information
weiznich committed Aug 15, 2024
1 parent 029a8d4 commit 88b9074
Show file tree
Hide file tree
Showing 39 changed files with 343 additions and 149 deletions.
5 changes: 4 additions & 1 deletion diesel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,10 @@
clippy::enum_glob_use,
clippy::if_not_else,
clippy::items_after_statements,
clippy::used_underscore_binding
clippy::used_underscore_binding,
clippy::cast_possible_wrap,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
#![deny(unsafe_code)]
#![cfg_attr(test, allow(clippy::map_unwrap_or, clippy::unwrap_used))]
Expand Down
20 changes: 15 additions & 5 deletions diesel/src/mysql/connection/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ impl Clone for BindData {
// written. At the time of writing this comment, the `BindData::bind_for_truncated_data`
// function is only called by `Binds::populate_dynamic_buffers` which ensures the corresponding
// invariant.
std::slice::from_raw_parts(ptr.as_ptr(), self.length as usize)
std::slice::from_raw_parts(
ptr.as_ptr(),
self.length.try_into().expect("usize is at least 32bit"),
)
};
let mut vec = slice.to_owned();
let ptr = NonNull::new(vec.as_mut_ptr());
Expand Down Expand Up @@ -415,7 +418,10 @@ impl BindData {
// written. At the time of writing this comment, the `BindData::bind_for_truncated_data`
// function is only called by `Binds::populate_dynamic_buffers` which ensures the corresponding
// invariant.
std::slice::from_raw_parts(data.as_ptr(), self.length as usize)
std::slice::from_raw_parts(
data.as_ptr(),
self.length.try_into().expect("Usize is at least 32 bit"),
)
};
Some(MysqlValue::new_internal(slice, tpe))
}
Expand All @@ -428,7 +434,10 @@ impl BindData {
fn update_buffer_length(&mut self) {
use std::cmp::min;

let actual_bytes_in_buffer = min(self.capacity, self.length as usize);
let actual_bytes_in_buffer = min(
self.capacity,
self.length.try_into().expect("Usize is at least 32 bit"),
);
self.length = actual_bytes_in_buffer as libc::c_ulong;
}

Expand Down Expand Up @@ -474,7 +483,8 @@ impl BindData {
self.bytes = None;

let offset = self.capacity;
let truncated_amount = self.length as usize - offset;
let truncated_amount =
usize::try_from(self.length).expect("Usize is at least 32 bit") - offset;

debug_assert!(
truncated_amount > 0,
Expand Down Expand Up @@ -504,7 +514,7 @@ impl BindData {
// offset is zero here as we don't have a buffer yet
// we know the requested length here so we can just request
// the correct size
let mut vec = vec![0_u8; self.length as usize];
let mut vec = vec![0_u8; self.length.try_into().expect("usize is at least 32 bit")];
self.capacity = vec.capacity();
self.bytes = NonNull::new(vec.as_mut_ptr());
mem::forget(vec);
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/mysql/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ impl Connection for MysqlConnection {
// we have not called result yet, so calling `execute` is
// fine
let stmt_use = unsafe { stmt.execute() }?;
Ok(stmt_use.affected_rows())
stmt_use.affected_rows()
}),
&mut self.transaction_state,
&mut self.instrumentation,
Expand Down
28 changes: 18 additions & 10 deletions diesel/src/mysql/connection/stmt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,11 @@ pub(super) struct StatementUse<'a> {
}

impl<'a> StatementUse<'a> {
pub(in crate::mysql::connection) fn affected_rows(&self) -> usize {
pub(in crate::mysql::connection) fn affected_rows(&self) -> QueryResult<usize> {
let affected_rows = unsafe { ffi::mysql_stmt_affected_rows(self.inner.stmt.as_ptr()) };
affected_rows as usize
affected_rows
.try_into()
.map_err(|e| Error::DeserializationError(Box::new(e)))
}

/// This function should be called after `execute` only
Expand All @@ -167,14 +169,19 @@ impl<'a> StatementUse<'a> {

pub(super) fn populate_row_buffers(&self, binds: &mut OutputBinds) -> QueryResult<Option<()>> {
let next_row_result = unsafe { ffi::mysql_stmt_fetch(self.inner.stmt.as_ptr()) };
match next_row_result as libc::c_uint {
ffi::MYSQL_NO_DATA => Ok(None),
ffi::MYSQL_DATA_TRUNCATED => binds.populate_dynamic_buffers(self).map(Some),
0 => {
binds.update_buffer_lengths();
Ok(Some(()))
if next_row_result < 0 {
self.inner.did_an_error_occur().map(Some)
} else {
#[allow(clippy::cast_sign_loss)] // that's how it's supposed to be based on the API
match next_row_result as libc::c_uint {
ffi::MYSQL_NO_DATA => Ok(None),
ffi::MYSQL_DATA_TRUNCATED => binds.populate_dynamic_buffers(self).map(Some),
0 => {
binds.update_buffer_lengths();
Ok(Some(()))
}
_error => self.inner.did_an_error_occur().map(Some),
}
_error => self.inner.did_an_error_occur().map(Some),
}
}

Expand All @@ -187,7 +194,8 @@ impl<'a> StatementUse<'a> {
ffi::mysql_stmt_fetch_column(
self.inner.stmt.as_ptr(),
bind,
idx as libc::c_uint,
idx.try_into()
.map_err(|e| Error::DeserializationError(Box::new(e)))?,
offset as libc::c_ulong,
);
self.inner.did_an_error_occur()
Expand Down
32 changes: 18 additions & 14 deletions diesel/src/mysql/types/date_and_time/chrono.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl FromSql<Datetime, Mysql> for NaiveDateTime {
impl ToSql<Timestamp, Mysql> for NaiveDateTime {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
let mysql_time = MysqlTime {
year: self.year() as libc::c_uint,
year: self.year().try_into()?,
month: self.month() as libc::c_uint,
day: self.day() as libc::c_uint,
hour: self.hour() as libc::c_uint,
Expand All @@ -48,16 +48,16 @@ impl FromSql<Timestamp, Mysql> for NaiveDateTime {
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let mysql_time = <MysqlTime as FromSql<Timestamp, Mysql>>::from_sql(bytes)?;

NaiveDate::from_ymd_opt(mysql_time.year as i32, mysql_time.month, mysql_time.day)
.and_then(|v| {
v.and_hms_micro_opt(
mysql_time.hour,
mysql_time.minute,
mysql_time.second,
mysql_time.second_part as u32,
)
})
.ok_or_else(|| format!("Cannot parse this date: {mysql_time:?}").into())
let micro = mysql_time.second_part.try_into()?;
NaiveDate::from_ymd_opt(
mysql_time.year.try_into()?,
mysql_time.month,
mysql_time.day,
)
.and_then(|v| {
v.and_hms_micro_opt(mysql_time.hour, mysql_time.minute, mysql_time.second, micro)
})
.ok_or_else(|| format!("Cannot parse this date: {mysql_time:?}").into())
}
}

Expand Down Expand Up @@ -94,7 +94,7 @@ impl FromSql<Time, Mysql> for NaiveTime {
impl ToSql<Date, Mysql> for NaiveDate {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
let mysql_time = MysqlTime {
year: self.year() as libc::c_uint,
year: self.year().try_into()?,
month: self.month() as libc::c_uint,
day: self.day() as libc::c_uint,
hour: 0,
Expand All @@ -114,8 +114,12 @@ impl ToSql<Date, Mysql> for NaiveDate {
impl FromSql<Date, Mysql> for NaiveDate {
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let mysql_time = <MysqlTime as FromSql<Date, Mysql>>::from_sql(bytes)?;
NaiveDate::from_ymd_opt(mysql_time.year as i32, mysql_time.month, mysql_time.day)
.ok_or_else(|| format!("Unable to convert {mysql_time:?} to chrono").into())
NaiveDate::from_ymd_opt(
mysql_time.year.try_into()?,
mysql_time.month,
mysql_time.day,
)
.ok_or_else(|| format!("Unable to convert {mysql_time:?} to chrono").into())
}
}

Expand Down
6 changes: 3 additions & 3 deletions diesel/src/mysql/types/date_and_time/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn to_time(dt: MysqlTime) -> Result<NaiveTime, Box<dyn std::error::Error>> {
("year", dt.year),
("month", dt.month),
("day", dt.day),
("offset", dt.time_zone_displacement as u32),
("offset", dt.time_zone_displacement.try_into()?),
] {
if field != 0 {
return Err(format!("Unable to convert {dt:?} to time: {name} must be 0").into());
Expand Down Expand Up @@ -63,7 +63,7 @@ fn to_primitive_datetime(dt: OffsetDateTime) -> PrimitiveDateTime {
impl ToSql<Datetime, Mysql> for PrimitiveDateTime {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
let mysql_time = MysqlTime {
year: self.year() as libc::c_uint,
year: self.year().try_into()?,
month: self.month() as libc::c_uint,
day: self.day() as libc::c_uint,
hour: self.hour() as libc::c_uint,
Expand Down Expand Up @@ -171,7 +171,7 @@ impl FromSql<Time, Mysql> for NaiveTime {
impl ToSql<Date, Mysql> for NaiveDate {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
let mysql_time = MysqlTime {
year: self.year() as libc::c_uint,
year: self.year().try_into()?,
month: self.month() as libc::c_uint,
day: self.day() as libc::c_uint,
hour: 0,
Expand Down
30 changes: 25 additions & 5 deletions diesel/src/mysql/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl ToSql<TinyInt, Mysql> for i8 {
impl FromSql<TinyInt, Mysql> for i8 {
fn from_sql(value: MysqlValue<'_>) -> deserialize::Result<Self> {
let bytes = value.as_bytes();
Ok(bytes[0] as i8)
Ok(i8::from_be_bytes([bytes[0]]))
}
}

Expand Down Expand Up @@ -69,12 +69,14 @@ where
#[cfg(feature = "mysql_backend")]
impl ToSql<Unsigned<TinyInt>, Mysql> for u8 {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
ToSql::<TinyInt, Mysql>::to_sql(&(*self as i8), &mut out.reborrow())
out.write_u8(*self)?;
Ok(IsNull::No)
}
}

#[cfg(feature = "mysql_backend")]
impl FromSql<Unsigned<TinyInt>, Mysql> for u8 {
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] // that's what we want
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let signed: i8 = FromSql::<TinyInt, Mysql>::from_sql(bytes)?;
Ok(signed as u8)
Expand All @@ -84,12 +86,18 @@ impl FromSql<Unsigned<TinyInt>, Mysql> for u8 {
#[cfg(feature = "mysql_backend")]
impl ToSql<Unsigned<SmallInt>, Mysql> for u16 {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
ToSql::<SmallInt, Mysql>::to_sql(&(*self as i16), &mut out.reborrow())
out.write_u16::<NativeEndian>(*self)?;
Ok(IsNull::No)
}
}

#[cfg(feature = "mysql_backend")]
impl FromSql<Unsigned<SmallInt>, Mysql> for u16 {
#[allow(
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_possible_truncation
)] // that's what we want
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let signed: i32 = FromSql::<Integer, Mysql>::from_sql(bytes)?;
Ok(signed as u16)
Expand All @@ -99,12 +107,18 @@ impl FromSql<Unsigned<SmallInt>, Mysql> for u16 {
#[cfg(feature = "mysql_backend")]
impl ToSql<Unsigned<Integer>, Mysql> for u32 {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
ToSql::<Integer, Mysql>::to_sql(&(*self as i32), &mut out.reborrow())
out.write_u32::<NativeEndian>(*self)?;
Ok(IsNull::No)
}
}

#[cfg(feature = "mysql_backend")]
impl FromSql<Unsigned<Integer>, Mysql> for u32 {
#[allow(
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_possible_truncation
)] // that's what we want
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let signed: i64 = FromSql::<BigInt, Mysql>::from_sql(bytes)?;
Ok(signed as u32)
Expand All @@ -114,12 +128,18 @@ impl FromSql<Unsigned<Integer>, Mysql> for u32 {
#[cfg(feature = "mysql_backend")]
impl ToSql<Unsigned<BigInt>, Mysql> for u64 {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
ToSql::<BigInt, Mysql>::to_sql(&(*self as i64), &mut out.reborrow())
out.write_u64::<NativeEndian>(*self)?;
Ok(IsNull::No)
}
}

#[cfg(feature = "mysql_backend")]
impl FromSql<Unsigned<BigInt>, Mysql> for u64 {
#[allow(
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_possible_truncation
)] // that's what we want
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let signed: i64 = FromSql::<BigInt, Mysql>::from_sql(bytes)?;
Ok(signed as u64)
Expand Down
4 changes: 4 additions & 0 deletions diesel/src/mysql/types/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ where
}
}

#[allow(clippy::cast_possible_truncation)] // that's what we want here
fn f32_to_i64(f: f32) -> deserialize::Result<i64> {
if f <= i64::MAX as f32 && f >= i64::MIN as f32 {
Ok(f.trunc() as i64)
Expand All @@ -32,6 +33,7 @@ fn f32_to_i64(f: f32) -> deserialize::Result<i64> {
}
}

#[allow(clippy::cast_possible_truncation)] // that's what we want here
fn f64_to_i64(f: f64) -> deserialize::Result<i64> {
if f <= i64::MAX as f64 && f >= i64::MIN as f64 {
Ok(f.trunc() as i64)
Expand Down Expand Up @@ -128,6 +130,8 @@ impl FromSql<Float, Mysql> for f32 {
NumericRepresentation::Medium(x) => Ok(x as Self),
NumericRepresentation::Big(x) => Ok(x as Self),
NumericRepresentation::Float(x) => Ok(x),
// there is currently no way to do this in a better way
#[allow(clippy::cast_possible_truncation)]
NumericRepresentation::Double(x) => Ok(x as Self),
NumericRepresentation::Decimal(bytes) => Ok(str::from_utf8(bytes)?.parse()?),
}
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/mysql/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<'a> MysqlValue<'a> {
pub(crate) fn numeric_value(&self) -> deserialize::Result<NumericRepresentation<'_>> {
Ok(match self.tpe {
MysqlType::UnsignedTiny | MysqlType::Tiny => {
NumericRepresentation::Tiny(self.raw[0] as i8)
NumericRepresentation::Tiny(self.raw[0].try_into()?)
}
MysqlType::UnsignedShort | MysqlType::Short => {
NumericRepresentation::Small(i16::from_ne_bytes((&self.raw[..2]).try_into()?))
Expand Down
5 changes: 4 additions & 1 deletion diesel/src/pg/connection/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ impl<'conn> BufRead for CopyToBuffer<'conn> {
let len =
pq_sys::PQgetCopyData(self.conn.internal_connection.as_ptr(), &mut self.ptr, 0);
match len {
len if len >= 0 => self.len = len as usize + 1,
len if len >= 0 => {
self.len = 1 + usize::try_from(len)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
}
-1 => self.len = 0,
_ => {
let error = self.conn.last_error_message();
Expand Down
4 changes: 3 additions & 1 deletion diesel/src/pg/connection/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ impl RawConnection {
pq_sys::PQputCopyData(
self.internal_connection.as_ptr(),
c.as_ptr() as *const libc::c_char,
c.len() as libc::c_int,
c.len()
.try_into()
.map_err(|e| Error::SerializationError(Box::new(e)))?,
)
};
if res != 1 {
Expand Down
Loading

0 comments on commit 88b9074

Please sign in to comment.