Skip to content

Commit

Permalink
Merge pull request diesel-rs#4170 from weiznich/prevent_protocol_leve…
Browse files Browse the repository at this point in the history
…l_size_overflows

Enable some numeric cast releated clippy lints and fix them in the code base
  • Loading branch information
weiznich committed Aug 23, 2024
1 parent 1a61cd3 commit 3b624ed
Show file tree
Hide file tree
Showing 41 changed files with 352 additions and 153 deletions.
5 changes: 4 additions & 1 deletion diesel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,10 @@
clippy::enum_glob_use,
clippy::if_not_else,
clippy::items_after_statements,
clippy::used_underscore_binding
clippy::used_underscore_binding,
clippy::cast_possible_wrap,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
#![deny(unsafe_code)]
#![cfg_attr(test, allow(clippy::map_unwrap_or, clippy::unwrap_used))]
Expand Down
20 changes: 15 additions & 5 deletions diesel/src/mysql/connection/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ impl Clone for BindData {
// written. At the time of writing this comment, the `BindData::bind_for_truncated_data`
// function is only called by `Binds::populate_dynamic_buffers` which ensures the corresponding
// invariant.
std::slice::from_raw_parts(ptr.as_ptr(), self.length as usize)
std::slice::from_raw_parts(
ptr.as_ptr(),
self.length.try_into().expect("usize is at least 32bit"),
)
};
let mut vec = slice.to_owned();
let ptr = NonNull::new(vec.as_mut_ptr());
Expand Down Expand Up @@ -416,7 +419,10 @@ impl BindData {
// written. At the time of writing this comment, the `BindData::bind_for_truncated_data`
// function is only called by `Binds::populate_dynamic_buffers` which ensures the corresponding
// invariant.
std::slice::from_raw_parts(data.as_ptr(), self.length as usize)
std::slice::from_raw_parts(
data.as_ptr(),
self.length.try_into().expect("Usize is at least 32 bit"),
)
};
Some(MysqlValue::new_internal(slice, tpe))
}
Expand All @@ -429,7 +435,10 @@ impl BindData {
fn update_buffer_length(&mut self) {
use std::cmp::min;

let actual_bytes_in_buffer = min(self.capacity, self.length as usize);
let actual_bytes_in_buffer = min(
self.capacity,
self.length.try_into().expect("Usize is at least 32 bit"),
);
self.length = actual_bytes_in_buffer as libc::c_ulong;
}

Expand Down Expand Up @@ -475,7 +484,8 @@ impl BindData {
self.bytes = None;

let offset = self.capacity;
let truncated_amount = self.length as usize - offset;
let truncated_amount =
usize::try_from(self.length).expect("Usize is at least 32 bit") - offset;

debug_assert!(
truncated_amount > 0,
Expand Down Expand Up @@ -505,7 +515,7 @@ impl BindData {
// offset is zero here as we don't have a buffer yet
// we know the requested length here so we can just request
// the correct size
let mut vec = vec![0_u8; self.length as usize];
let mut vec = vec![0_u8; self.length.try_into().expect("usize is at least 32 bit")];
self.capacity = vec.capacity();
self.bytes = NonNull::new(vec.as_mut_ptr());
mem::forget(vec);
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/mysql/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl Connection for MysqlConnection {
// we have not called result yet, so calling `execute` is
// fine
let stmt_use = unsafe { stmt.execute() }?;
Ok(stmt_use.affected_rows())
stmt_use.affected_rows()
}),
&mut self.transaction_state,
&mut self.instrumentation,
Expand Down
28 changes: 18 additions & 10 deletions diesel/src/mysql/connection/stmt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,11 @@ pub(super) struct StatementUse<'a> {
}

impl<'a> StatementUse<'a> {
pub(in crate::mysql::connection) fn affected_rows(&self) -> usize {
pub(in crate::mysql::connection) fn affected_rows(&self) -> QueryResult<usize> {
let affected_rows = unsafe { ffi::mysql_stmt_affected_rows(self.inner.stmt.as_ptr()) };
affected_rows as usize
affected_rows
.try_into()
.map_err(|e| Error::DeserializationError(Box::new(e)))
}

/// This function should be called after `execute` only
Expand All @@ -167,14 +169,19 @@ impl<'a> StatementUse<'a> {

pub(super) fn populate_row_buffers(&self, binds: &mut OutputBinds) -> QueryResult<Option<()>> {
let next_row_result = unsafe { ffi::mysql_stmt_fetch(self.inner.stmt.as_ptr()) };
match next_row_result as libc::c_uint {
ffi::MYSQL_NO_DATA => Ok(None),
ffi::MYSQL_DATA_TRUNCATED => binds.populate_dynamic_buffers(self).map(Some),
0 => {
binds.update_buffer_lengths();
Ok(Some(()))
if next_row_result < 0 {
self.inner.did_an_error_occur().map(Some)
} else {
#[allow(clippy::cast_sign_loss)] // that's how it's supposed to be based on the API
match next_row_result as libc::c_uint {
ffi::MYSQL_NO_DATA => Ok(None),
ffi::MYSQL_DATA_TRUNCATED => binds.populate_dynamic_buffers(self).map(Some),
0 => {
binds.update_buffer_lengths();
Ok(Some(()))
}
_error => self.inner.did_an_error_occur().map(Some),
}
_error => self.inner.did_an_error_occur().map(Some),
}
}

Expand All @@ -187,7 +194,8 @@ impl<'a> StatementUse<'a> {
ffi::mysql_stmt_fetch_column(
self.inner.stmt.as_ptr(),
bind,
idx as libc::c_uint,
idx.try_into()
.map_err(|e| Error::DeserializationError(Box::new(e)))?,
offset as libc::c_ulong,
);
self.inner.did_an_error_occur()
Expand Down
32 changes: 18 additions & 14 deletions diesel/src/mysql/types/date_and_time/chrono.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl FromSql<Datetime, Mysql> for NaiveDateTime {
impl ToSql<Timestamp, Mysql> for NaiveDateTime {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
let mysql_time = MysqlTime {
year: self.year() as libc::c_uint,
year: self.year().try_into()?,
month: self.month() as libc::c_uint,
day: self.day() as libc::c_uint,
hour: self.hour() as libc::c_uint,
Expand All @@ -48,16 +48,16 @@ impl FromSql<Timestamp, Mysql> for NaiveDateTime {
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let mysql_time = <MysqlTime as FromSql<Timestamp, Mysql>>::from_sql(bytes)?;

NaiveDate::from_ymd_opt(mysql_time.year as i32, mysql_time.month, mysql_time.day)
.and_then(|v| {
v.and_hms_micro_opt(
mysql_time.hour,
mysql_time.minute,
mysql_time.second,
mysql_time.second_part as u32,
)
})
.ok_or_else(|| format!("Cannot parse this date: {mysql_time:?}").into())
let micro = mysql_time.second_part.try_into()?;
NaiveDate::from_ymd_opt(
mysql_time.year.try_into()?,
mysql_time.month,
mysql_time.day,
)
.and_then(|v| {
v.and_hms_micro_opt(mysql_time.hour, mysql_time.minute, mysql_time.second, micro)
})
.ok_or_else(|| format!("Cannot parse this date: {mysql_time:?}").into())
}
}

Expand Down Expand Up @@ -94,7 +94,7 @@ impl FromSql<Time, Mysql> for NaiveTime {
impl ToSql<Date, Mysql> for NaiveDate {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
let mysql_time = MysqlTime {
year: self.year() as libc::c_uint,
year: self.year().try_into()?,
month: self.month() as libc::c_uint,
day: self.day() as libc::c_uint,
hour: 0,
Expand All @@ -114,8 +114,12 @@ impl ToSql<Date, Mysql> for NaiveDate {
impl FromSql<Date, Mysql> for NaiveDate {
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let mysql_time = <MysqlTime as FromSql<Date, Mysql>>::from_sql(bytes)?;
NaiveDate::from_ymd_opt(mysql_time.year as i32, mysql_time.month, mysql_time.day)
.ok_or_else(|| format!("Unable to convert {mysql_time:?} to chrono").into())
NaiveDate::from_ymd_opt(
mysql_time.year.try_into()?,
mysql_time.month,
mysql_time.day,
)
.ok_or_else(|| format!("Unable to convert {mysql_time:?} to chrono").into())
}
}

Expand Down
6 changes: 3 additions & 3 deletions diesel/src/mysql/types/date_and_time/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn to_time(dt: MysqlTime) -> Result<NaiveTime, Box<dyn std::error::Error>> {
("year", dt.year),
("month", dt.month),
("day", dt.day),
("offset", dt.time_zone_displacement as u32),
("offset", dt.time_zone_displacement.try_into()?),
] {
if field != 0 {
return Err(format!("Unable to convert {dt:?} to time: {name} must be 0").into());
Expand Down Expand Up @@ -63,7 +63,7 @@ fn to_primitive_datetime(dt: OffsetDateTime) -> PrimitiveDateTime {
impl ToSql<Datetime, Mysql> for PrimitiveDateTime {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
let mysql_time = MysqlTime {
year: self.year() as libc::c_uint,
year: self.year().try_into()?,
month: self.month() as libc::c_uint,
day: self.day() as libc::c_uint,
hour: self.hour() as libc::c_uint,
Expand Down Expand Up @@ -171,7 +171,7 @@ impl FromSql<Time, Mysql> for NaiveTime {
impl ToSql<Date, Mysql> for NaiveDate {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
let mysql_time = MysqlTime {
year: self.year() as libc::c_uint,
year: self.year().try_into()?,
month: self.month() as libc::c_uint,
day: self.day() as libc::c_uint,
hour: 0,
Expand Down
30 changes: 25 additions & 5 deletions diesel/src/mysql/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl ToSql<TinyInt, Mysql> for i8 {
impl FromSql<TinyInt, Mysql> for i8 {
fn from_sql(value: MysqlValue<'_>) -> deserialize::Result<Self> {
let bytes = value.as_bytes();
Ok(bytes[0] as i8)
Ok(i8::from_be_bytes([bytes[0]]))
}
}

Expand Down Expand Up @@ -69,12 +69,14 @@ where
#[cfg(feature = "mysql_backend")]
impl ToSql<Unsigned<TinyInt>, Mysql> for u8 {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
ToSql::<TinyInt, Mysql>::to_sql(&(*self as i8), &mut out.reborrow())
out.write_u8(*self)?;
Ok(IsNull::No)
}
}

#[cfg(feature = "mysql_backend")]
impl FromSql<Unsigned<TinyInt>, Mysql> for u8 {
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] // that's what we want
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let signed: i8 = FromSql::<TinyInt, Mysql>::from_sql(bytes)?;
Ok(signed as u8)
Expand All @@ -84,12 +86,18 @@ impl FromSql<Unsigned<TinyInt>, Mysql> for u8 {
#[cfg(feature = "mysql_backend")]
impl ToSql<Unsigned<SmallInt>, Mysql> for u16 {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
ToSql::<SmallInt, Mysql>::to_sql(&(*self as i16), &mut out.reborrow())
out.write_u16::<NativeEndian>(*self)?;
Ok(IsNull::No)
}
}

#[cfg(feature = "mysql_backend")]
impl FromSql<Unsigned<SmallInt>, Mysql> for u16 {
#[allow(
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_possible_truncation
)] // that's what we want
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let signed: i32 = FromSql::<Integer, Mysql>::from_sql(bytes)?;
Ok(signed as u16)
Expand All @@ -99,12 +107,18 @@ impl FromSql<Unsigned<SmallInt>, Mysql> for u16 {
#[cfg(feature = "mysql_backend")]
impl ToSql<Unsigned<Integer>, Mysql> for u32 {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
ToSql::<Integer, Mysql>::to_sql(&(*self as i32), &mut out.reborrow())
out.write_u32::<NativeEndian>(*self)?;
Ok(IsNull::No)
}
}

#[cfg(feature = "mysql_backend")]
impl FromSql<Unsigned<Integer>, Mysql> for u32 {
#[allow(
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_possible_truncation
)] // that's what we want
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let signed: i64 = FromSql::<BigInt, Mysql>::from_sql(bytes)?;
Ok(signed as u32)
Expand All @@ -114,12 +128,18 @@ impl FromSql<Unsigned<Integer>, Mysql> for u32 {
#[cfg(feature = "mysql_backend")]
impl ToSql<Unsigned<BigInt>, Mysql> for u64 {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
ToSql::<BigInt, Mysql>::to_sql(&(*self as i64), &mut out.reborrow())
out.write_u64::<NativeEndian>(*self)?;
Ok(IsNull::No)
}
}

#[cfg(feature = "mysql_backend")]
impl FromSql<Unsigned<BigInt>, Mysql> for u64 {
#[allow(
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_possible_truncation
)] // that's what we want
fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result<Self> {
let signed: i64 = FromSql::<BigInt, Mysql>::from_sql(bytes)?;
Ok(signed as u64)
Expand Down
4 changes: 4 additions & 0 deletions diesel/src/mysql/types/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ where
}
}

#[allow(clippy::cast_possible_truncation)] // that's what we want here
fn f32_to_i64(f: f32) -> deserialize::Result<i64> {
if f <= i64::MAX as f32 && f >= i64::MIN as f32 {
Ok(f.trunc() as i64)
Expand All @@ -32,6 +33,7 @@ fn f32_to_i64(f: f32) -> deserialize::Result<i64> {
}
}

#[allow(clippy::cast_possible_truncation)] // that's what we want here
fn f64_to_i64(f: f64) -> deserialize::Result<i64> {
if f <= i64::MAX as f64 && f >= i64::MIN as f64 {
Ok(f.trunc() as i64)
Expand Down Expand Up @@ -128,6 +130,8 @@ impl FromSql<Float, Mysql> for f32 {
NumericRepresentation::Medium(x) => Ok(x as Self),
NumericRepresentation::Big(x) => Ok(x as Self),
NumericRepresentation::Float(x) => Ok(x),
// there is currently no way to do this in a better way
#[allow(clippy::cast_possible_truncation)]
NumericRepresentation::Double(x) => Ok(x as Self),
NumericRepresentation::Decimal(bytes) => Ok(str::from_utf8(bytes)?.parse()?),
}
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/mysql/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<'a> MysqlValue<'a> {
pub(crate) fn numeric_value(&self) -> deserialize::Result<NumericRepresentation<'_>> {
Ok(match self.tpe {
MysqlType::UnsignedTiny | MysqlType::Tiny => {
NumericRepresentation::Tiny(self.raw[0] as i8)
NumericRepresentation::Tiny(self.raw[0].try_into()?)
}
MysqlType::UnsignedShort | MysqlType::Short => {
NumericRepresentation::Small(i16::from_ne_bytes((&self.raw[..2]).try_into()?))
Expand Down
5 changes: 4 additions & 1 deletion diesel/src/pg/connection/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ impl<'conn> BufRead for CopyToBuffer<'conn> {
let len =
pq_sys::PQgetCopyData(self.conn.internal_connection.as_ptr(), &mut self.ptr, 0);
match len {
len if len >= 0 => self.len = len as usize + 1,
len if len >= 0 => {
self.len = 1 + usize::try_from(len)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?
}
-1 => self.len = 0,
_ => {
let error = self.conn.last_error_message();
Expand Down
4 changes: 3 additions & 1 deletion diesel/src/pg/connection/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ impl RawConnection {
pq_sys::PQputCopyData(
self.internal_connection.as_ptr(),
c.as_ptr() as *const libc::c_char,
c.len() as libc::c_int,
c.len()
.try_into()
.map_err(|e| Error::SerializationError(Box::new(e)))?,
)
};
if res != 1 {
Expand Down
Loading

0 comments on commit 3b624ed

Please sign in to comment.