From d360f682f8ff991d949ff7e6a4d0c13dc6738f79 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Fri, 10 Apr 2020 13:37:08 -0700 Subject: [PATCH] fix(postgres): guarantee the type name on a PgTypeInfo to always be set fixes #241 --- sqlx-core/src/describe.rs | 1 + sqlx-core/src/postgres/cursor.rs | 86 +------- sqlx-core/src/postgres/executor.rs | 325 +++++++++++++++------------- sqlx-core/src/postgres/row.rs | 10 +- sqlx-core/src/postgres/type_info.rs | 11 +- sqlx-core/src/postgres/value.rs | 6 +- tests/postgres-derives.rs | 108 ++++++++- 7 files changed, 316 insertions(+), 231 deletions(-) diff --git a/sqlx-core/src/describe.rs b/sqlx-core/src/describe.rs index c9e351a9c0..2a7c2294c6 100644 --- a/sqlx-core/src/describe.rs +++ b/sqlx-core/src/describe.rs @@ -12,6 +12,7 @@ pub struct Describe where DB: Database + ?Sized, { + // TODO: Describe#param_types should probably be Option as we either know all the params or we know none /// The expected types for the parameters of the query. pub param_types: Box<[Option]>, diff --git a/sqlx-core/src/postgres/cursor.rs b/sqlx-core/src/postgres/cursor.rs index 13f653ef72..f005623227 100644 --- a/sqlx-core/src/postgres/cursor.rs +++ b/sqlx-core/src/postgres/cursor.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::sync::Arc; use futures_core::future::BoxFuture; @@ -7,8 +6,8 @@ use crate::connection::ConnectionSource; use crate::cursor::Cursor; use crate::executor::Execute; use crate::pool::Pool; -use crate::postgres::protocol::{DataRow, Message, ReadyForQuery, RowDescription, StatementId}; -use crate::postgres::row::{Column, Statement}; +use crate::postgres::protocol::{DataRow, Message, ReadyForQuery, RowDescription}; +use crate::postgres::row::Statement; use crate::postgres::{PgArguments, PgConnection, PgRow, Postgres}; pub struct PgCursor<'c, 'q> { @@ -53,76 +52,6 @@ impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> { } } -fn parse_row_description(conn: &mut PgConnection, rd: RowDescription) -> Statement { - let mut 
names = HashMap::new(); - let mut columns = Vec::new(); - - columns.reserve(rd.fields.len()); - names.reserve(rd.fields.len()); - - for (index, field) in rd.fields.iter().enumerate() { - if let Some(name) = &field.name { - names.insert(name.clone(), index); - } - - let type_info = conn.get_type_info_by_oid(field.type_id.0); - - columns.push(Column { - type_info, - format: field.type_format, - }); - } - - Statement { - columns: columns.into_boxed_slice(), - names, - } -} - -// Used to describe the incoming results -// We store the column map in an Arc and share it among all rows -async fn expect_desc(conn: &mut PgConnection) -> crate::Result { - let description: Option<_> = loop { - match conn.stream.receive().await? { - Message::ParseComplete | Message::BindComplete => {} - - Message::RowDescription => { - break Some(RowDescription::read(conn.stream.buffer())?); - } - - Message::NoData => { - break None; - } - - message => { - return Err( - protocol_err!("next/describe: unexpected message: {:?}", message).into(), - ); - } - } - }; - - if let Some(description) = description { - Ok(parse_row_description(conn, description)) - } else { - Ok(Statement::default()) - } -} - -// A form of describe that uses the statement cache -async fn get_or_describe( - conn: &mut PgConnection, - id: StatementId, -) -> crate::Result> { - if !conn.cache_statement.contains_key(&id) { - let statement = expect_desc(conn).await?; - - conn.cache_statement.insert(id, Arc::new(statement)); - } - - Ok(Arc::clone(&conn.cache_statement[&id])) -} - async fn next<'a, 'c: 'a, 'q: 'a>( cursor: &'a mut PgCursor<'c, 'q>, ) -> crate::Result>> { @@ -136,9 +65,8 @@ async fn next<'a, 'c: 'a, 'q: 'a>( // If there is a statement ID, this is a non-simple or prepared query if let Some(statement) = statement { - // A prepared statement will re-use the previous column map if - // this query has been executed before - cursor.statement = get_or_describe(&mut *conn, statement).await?; + // A prepared statement will 
re-use the previous column map + cursor.statement = Arc::clone(&conn.cache_statement[&statement]); } // A non-prepared query must be described each time @@ -164,8 +92,12 @@ async fn next<'a, 'c: 'a, 'q: 'a>( } Message::RowDescription => { + // NOTE: This is only encountered for unprepared statements let rd = RowDescription::read(conn.stream.buffer())?; - cursor.statement = Arc::new(parse_row_description(conn, rd)); + cursor.statement = Arc::new( + conn.parse_row_description(rd, Default::default(), None, false) + .await?, + ); } Message::DataRow => { diff --git a/sqlx-core/src/postgres/executor.rs b/sqlx-core/src/postgres/executor.rs index 43c247075f..d5b834f4bd 100644 --- a/sqlx-core/src/postgres/executor.rs +++ b/sqlx-core/src/postgres/executor.rs @@ -1,5 +1,6 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fmt::Write; +use std::sync::Arc; use futures_core::future::BoxFuture; use futures_util::{stream, StreamExt, TryStreamExt}; @@ -9,9 +10,11 @@ use crate::cursor::Cursor; use crate::describe::{Column, Describe}; use crate::executor::{Execute, Executor, RefExecutor}; use crate::postgres::protocol::{ - self, CommandComplete, Field, Message, ParameterDescription, ReadyForQuery, RowDescription, + self, CommandComplete, Message, ParameterDescription, ReadyForQuery, RowDescription, StatementId, TypeFormat, TypeId, }; +use crate::postgres::row::Column as StatementColumn; +use crate::postgres::row::Statement; use crate::postgres::type_info::SharedStr; use crate::postgres::types::try_resolve_type_name; use crate::postgres::{ @@ -56,10 +59,137 @@ impl PgConnection { query, }); + // [Describe] will return the expected result columns and types + self.write_describe(protocol::Describe::Statement(id)); + self.write_sync(); + + // Flush commands and handle ParseComplete and RowDescription + self.wait_until_ready().await?; + self.stream.flush().await?; + self.is_ready = false; + + // wait for `ParseComplete` + match 
self.stream.receive().await? { + Message::ParseComplete => {} + message => { + return Err(protocol_err!("run: unexpected message: {:?}", message).into()); + } + } + + // expecting a `ParameterDescription` next + let pd = self.expect_param_desc().await?; + + // expecting a `RowDescription` next (or `NoData` for an empty statement) + let statement = self.expect_row_desc(pd).await?; + + // cache statement ID and statement description + self.cache_statement_id.insert(query.into(), id); + self.cache_statement.insert(id, Arc::new(statement)); + Ok(id) } } + async fn parse_parameter_description( + &mut self, + pd: ParameterDescription, + ) -> crate::Result> { + let mut params = Vec::with_capacity(pd.ids.len()); + + for ty in pd.ids.iter() { + let type_info = self.get_type_info_by_oid(ty.0, true).await?; + + params.push(type_info); + } + + Ok(params.into_boxed_slice()) + } + + pub(crate) async fn parse_row_description( + &mut self, + mut rd: RowDescription, + params: Box<[PgTypeInfo]>, + type_format: Option, + fetch_type_info: bool, + ) -> crate::Result { + let mut names = HashMap::new(); + let mut columns = Vec::new(); + + columns.reserve(rd.fields.len()); + names.reserve(rd.fields.len()); + + for (index, field) in rd.fields.iter_mut().enumerate() { + let name = if let Some(name) = field.name.take() { + let name = SharedStr::from(name.into_string()); + names.insert(name.clone(), index); + Some(name) + } else { + None + }; + + let type_info = self + .get_type_info_by_oid(field.type_id.0, fetch_type_info) + .await?; + + columns.push(StatementColumn { + type_info, + name, + format: type_format.unwrap_or(field.type_format), + table_id: field.table_id, + column_id: field.column_id, + }); + } + + Ok(Statement { + params, + columns: columns.into_boxed_slice(), + names, + }) + } + + async fn expect_param_desc(&mut self) -> crate::Result { + let description = match self.stream.receive().await? 
{ + Message::ParameterDescription => ParameterDescription::read(self.stream.buffer())?, + + message => { + return Err( + protocol_err!("next/describe: unexpected message: {:?}", message).into(), + ); + } + }; + + Ok(description) + } + + // Used to describe the incoming results + // We store the column map in an Arc and share it among all rows + async fn expect_row_desc(&mut self, pd: ParameterDescription) -> crate::Result { + let description: Option<_> = match self.stream.receive().await? { + Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?), + + Message::NoData => None, + + message => { + return Err( + protocol_err!("next/describe: unexpected message: {:?}", message).into(), + ); + } + }; + + let params = self.parse_parameter_description(pd).await?; + + if let Some(description) = description { + self.parse_row_description(description, params, Some(TypeFormat::Binary), true) + .await + } else { + Ok(Statement { + params, + names: HashMap::new(), + columns: Default::default(), + }) + } + } + pub(crate) fn write_describe(&mut self, d: protocol::Describe) { self.stream.write(d); } @@ -132,12 +262,6 @@ impl PgConnection { // Next, [Bind] attaches the arguments to the statement and creates a named portal self.write_bind("", statement, &mut arguments).await?; - // Next, [Describe] will return the expected result columns and types - // Conditionally run [Describe] only if the results have not been cached - if !self.cache_statement.contains_key(&statement) { - self.write_describe(protocol::Describe::Portal("")); - } - // Next, [Execute] then executes the named portal self.write_execute("", 0); @@ -161,24 +285,6 @@ impl PgConnection { self.stream.flush().await?; self.is_ready = false; - // only cache - if let Some(statement) = statement { - // prefer redundant lookup to copying the query string - if !self.cache_statement_id.contains_key(query) { - // wait for `ParseComplete` on the stream or the - // error before we cache the statement - match 
self.stream.receive().await? { - Message::ParseComplete => { - self.cache_statement_id.insert(query.into(), statement); - } - - message => { - return Err(protocol_err!("run: unexpected message: {:?}", message).into()); - } - } - } - } - Ok(statement) } @@ -186,71 +292,18 @@ impl PgConnection { &'e mut self, query: &'q str, ) -> crate::Result> { - self.is_ready = false; - - let statement = self.write_prepare(query, &Default::default()).await?; - - self.write_describe(protocol::Describe::Statement(statement)); - self.write_sync(); - - self.stream.flush().await?; - - let params = loop { - match self.stream.receive().await? { - Message::ParseComplete => {} - - Message::ParameterDescription => { - break ParameterDescription::read(self.stream.buffer())?; - } - - message => { - return Err(protocol_err!( - "expected ParameterDescription; received {:?}", - message - ) - .into()); - } - }; - }; - - let result = match self.stream.receive().await? { - Message::NoData => None, - Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?), - - message => { - return Err(protocol_err!( - "expected RowDescription or NoData; received {:?}", - message - ) - .into()); - } - }; - - self.wait_until_ready().await?; - - let result_fields = result.map_or_else(Default::default, |r| r.fields); - - let type_names = self - .get_type_names( - params - .ids - .iter() - .cloned() - .chain(result_fields.iter().map(|field| field.type_id)), - ) - .await?; + let statement_id = self.write_prepare(query, &Default::default()).await?; + let statement = &self.cache_statement[&statement_id]; + let columns = statement.columns.to_vec(); Ok(Describe { - param_types: params - .ids + param_types: statement + .params .iter() - .map(|id| Some(PgTypeInfo::new(*id, &type_names[&id.0]))) + .map(|info| Some(info.clone())) .collect::>() .into_boxed_slice(), - result_columns: self - .map_result_columns(result_fields, type_names) - .await? 
- .into_boxed_slice(), + result_columns: self.map_result_columns(columns).await?.into_boxed_slice(), }) } @@ -277,71 +330,50 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 Ok(oid) } - pub(crate) fn get_type_info_by_oid(&mut self, oid: u32) -> PgTypeInfo { + pub(crate) async fn get_type_info_by_oid( + &mut self, + oid: u32, + fetch_type_info: bool, + ) -> crate::Result { if let Some(name) = try_resolve_type_name(oid) { - return PgTypeInfo::new(TypeId(oid), name); + return Ok(PgTypeInfo::new(TypeId(oid), name)); } if let Some(name) = self.cache_type_name.get(&oid) { - return PgTypeInfo::new(TypeId(oid), name); - } - - // NOTE: The name isn't too important for the decode lifecycle - return PgTypeInfo::new(TypeId(oid), ""); - } - - async fn get_type_names( - &mut self, - ids: impl IntoIterator, - ) -> crate::Result> { - let type_ids: HashSet = ids.into_iter().map(|id| id.0).collect::>(); - - if type_ids.is_empty() { - return Ok(HashMap::new()); + return Ok(PgTypeInfo::new(TypeId(oid), name)); } - // uppercase type names are easier to visually identify - let mut query = "select types.type_id, UPPER(pg_type.typname) from (VALUES ".to_string(); - let mut args = PgArguments::default(); - let mut pushed = false; - - // TODO: dedup this with the one below, ideally as an API we can export - for (i, (&type_id, bind)) in type_ids.iter().zip((1..).step_by(2)).enumerate() { - if pushed { - query += ", "; - } + let name = if fetch_type_info { + // language=SQL + let (name,): (String,) = query_as( + " + SELECT UPPER(typname) FROM pg_catalog.pg_type WHERE oid = $1 + ", + ) + .bind(oid) + .fetch_one(&mut *self) + .await?; - pushed = true; - let _ = write!(query, "(${}, ${})", bind, bind + 1); + // Emplace the new type name <-> OID association in the cache + let shared = SharedStr::from(name); - // not used in the output but ensures are values are sorted correctly - args.add(i as i32); - args.add(type_id as i32); - } + self.cache_type_oid.insert(shared.clone(), oid); + 
self.cache_type_name.insert(oid, shared.clone()); - query += ") as types(idx, type_id) \ - inner join pg_catalog.pg_type on pg_type.oid = type_id \ - order by types.idx"; + shared + } else { + // NOTE: The name isn't too important for the decode lifecycle of TEXT + SharedStr::Static("") + }; - crate::query::query(&query) - .bind_all(args) - .try_map(|row: PgRow| -> crate::Result<(u32, SharedStr)> { - Ok(( - row.try_get::(0)? as u32, - row.try_get::(1)?.into(), - )) - }) - .fetch(self) - .try_collect() - .await + Ok(PgTypeInfo::new(TypeId(oid), name)) } async fn map_result_columns( &mut self, - fields: Box<[Field]>, - type_names: HashMap, + columns: Vec, ) -> crate::Result>> { - if fields.is_empty() { + if columns.is_empty() { return Ok(vec![]); } @@ -349,7 +381,7 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 let mut pushed = false; let mut args = PgArguments::default(); - for (i, (field, bind)) in fields.iter().zip((1..).step_by(3)).enumerate() { + for (i, (column, bind)) in columns.iter().zip((1..).step_by(3)).enumerate() { if pushed { query += ", "; } @@ -364,8 +396,8 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 ); args.add(i as i32); - args.add(field.table_id.map(|id| id as i32)); - args.add(field.column_id); + args.add(column.table_id.map(|id| id as i32)); + args.add(column.column_id); } query += ") as col(idx, table_id, col_idx) \ @@ -383,23 +415,20 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 Ok((idx, non_null)) }) .fetch(self) - .zip(stream::iter(fields.into_vec().into_iter().enumerate())) - .map(|(row, (fidx, field))| -> crate::Result> { + .zip(stream::iter(columns.into_iter().enumerate())) + .map(|(row, (fidx, column))| -> crate::Result> { let (idx, non_null) = row?; if idx != fidx as i32 { return Err( - protocol_err!("missing field from query, field: {:?}", field).into(), + protocol_err!("missing field from query, field: {:?}", column).into(), ); } Ok(Column { - name: field.name, - table_id: field.table_id, - 
type_info: Some(PgTypeInfo::new( - field.type_id, - &type_names[&field.type_id.0], - )), + name: column.name.map(|name| (&*name).into()), + table_id: column.table_id, + type_info: Some(column.type_info), non_null, }) }) diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 345111ffa0..169e376d90 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use crate::postgres::protocol::{DataRow, TypeFormat}; +use crate::postgres::type_info::SharedStr; use crate::postgres::value::PgValue; use crate::postgres::{PgTypeInfo, Postgres}; use crate::row::{ColumnIndex, Row}; @@ -10,17 +11,24 @@ use crate::row::{ColumnIndex, Row}; // For Postgres, each column has an OID and a format (binary or text) // For simple (unprepared) queries, format will always be text // For prepared queries, format will _almost_ always be binary +#[derive(Clone, Debug)] pub(crate) struct Column { + pub(crate) name: Option, pub(crate) type_info: PgTypeInfo, pub(crate) format: TypeFormat, + pub(crate) table_id: Option, + pub(crate) column_id: i16, } // A statement description containing the column information used to // properly decode data #[derive(Default)] pub(crate) struct Statement { + // parameters + pub(crate) params: Box<[PgTypeInfo]>, + // column name -> position - pub(crate) names: HashMap, usize>, + pub(crate) names: HashMap, // all columns pub(crate) columns: Box<[Column]>, diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index d3e08e2c34..e953ae8270 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -3,6 +3,7 @@ use crate::types::TypeInfo; use std::borrow::Borrow; use std::fmt; use std::fmt::Display; +use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; @@ -135,7 +136,7 @@ impl TypeInfo for PgTypeInfo { } /// Copy of `Cow` but for strings; clones guaranteed to be cheap. 
-#[derive(Clone, Debug, PartialEq, Hash, Eq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub(crate) enum SharedStr { Static(&'static str), Arc(Arc), @@ -152,6 +153,14 @@ impl Deref for SharedStr { } } +impl Hash for SharedStr { + fn hash(&self, state: &mut H) { + // Forward the hash to the string representation of this + // A derive(Hash) encodes the enum discriminant + (&**self).hash(state); + } +} + impl Borrow for SharedStr { fn borrow(&self) -> &str { &**self diff --git a/sqlx-core/src/postgres/value.rs b/sqlx-core/src/postgres/value.rs index 61d00bd636..f812e1d800 100644 --- a/sqlx-core/src/postgres/value.rs +++ b/sqlx-core/src/postgres/value.rs @@ -72,9 +72,11 @@ impl<'c> PgValue<'c> { impl<'c> RawValue<'c> for PgValue<'c> { type Database = Postgres; + // The public type_info is used for type compatibility checks fn type_info(&self) -> Option { - if let (Some(type_info), Some(_)) = (&self.type_info, &self.data) { - Some(type_info.clone()) + // For TEXT encoding the type defined on the value is unreliable + if matches!(self.data, Some(PgData::Binary(_))) { + self.type_info.clone() } else { None } diff --git a/tests/postgres-derives.rs b/tests/postgres-derives.rs index 4538612e79..c23de309bb 100644 --- a/tests/postgres-derives.rs +++ b/tests/postgres-derives.rs @@ -1,4 +1,4 @@ -use sqlx::{postgres::PgQueryAs, Executor, Postgres}; +use sqlx::{postgres::PgQueryAs, Connection, Cursor, Executor, FromRow, Postgres}; use sqlx_test::{new, test_type}; use std::fmt::Debug; @@ -16,7 +16,7 @@ enum Weak { Three = 4, } -// "Strong" enums can map to TEXT (25) or a custom enum type +// "Strong" enums can map to TEXT (25) #[derive(PartialEq, Debug, sqlx::Type)] #[sqlx(rename = "text")] #[sqlx(rename_all = "lowercase")] @@ -28,6 +28,16 @@ enum Strong { Three, } +// "Strong" enum can map to a custom type +#[derive(PartialEq, Debug, sqlx::Type)] +#[sqlx(rename = "mood")] +#[sqlx(rename_all = "lowercase")] +enum Mood { + Ok, + Happy, + Sad, +} + // Records must map to a custom type 
// Note that all types are types in Postgres #[derive(PartialEq, Debug, sqlx::Type)] @@ -61,6 +71,100 @@ test_type!(strong_enum( "'four'::text" == Strong::Three )); +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_enum_type() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute( + r#" +DO $$ BEGIN + +CREATE TYPE mood AS ENUM ( 'ok', 'happy', 'sad' ); + +EXCEPTION + WHEN duplicate_object THEN null; +END $$; + +CREATE TABLE IF NOT EXISTS people ( + id serial PRIMARY KEY, + mood mood not null +); + +TRUNCATE people; + "#, + ) + .await?; + + // Drop and re-acquire the connection + conn.close().await?; + let mut conn = new::().await?; + + // Select from table test + let (people_id,): (i32,) = sqlx::query_as( + " +INSERT INTO people (mood) +VALUES ($1) +RETURNING id + ", + ) + .bind(Mood::Sad) + .fetch_one(&mut conn) + .await?; + + // Drop and re-acquire the connection + conn.close().await?; + let mut conn = new::().await?; + + #[derive(sqlx::FromRow)] + struct PeopleRow { + id: i32, + mood: Mood, + } + + let rec: PeopleRow = sqlx::query_as( + " +SELECT id, mood FROM people WHERE id = $1 + ", + ) + .bind(people_id) + .fetch_one(&mut conn) + .await?; + + assert_eq!(rec.id, people_id); + assert_eq!(rec.mood, Mood::Sad); + + // Drop and re-acquire the connection + conn.close().await?; + let mut conn = new::().await?; + + let stmt = format!("SELECT id, mood FROM people WHERE id = {}", people_id); + dbg!(&stmt); + let mut cursor = conn.fetch(&*stmt); + + let row = cursor.next().await?.unwrap(); + let rec = PeopleRow::from_row(&row)?; + + assert_eq!(rec.id, people_id); + assert_eq!(rec.mood, Mood::Sad); + + // Normal type equivalency test + + let rec: (bool, Mood) = sqlx::query_as( + " +SELECT $1 = 'happy'::mood, $1 + ", + ) + .bind(&Mood::Happy) + .fetch_one(&mut conn) + .await?; + + assert!(rec.0); + assert_eq!(rec.1, Mood::Happy); + + Ok(()) +} + #[cfg_attr(feature = 
"runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn test_record_type() -> anyhow::Result<()> {