Skip to content

Commit

Permalink
feat: implement Self::with_init_sql for relational databases
Browse files Browse the repository at this point in the history
  • Loading branch information
CommanderStorm committed Sep 25, 2024
1 parent 99dbabe commit 17645a3
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ kwok = []
[dependencies]
# TODO: update parse-display after MSRV>=1.80.0 bump of `testcontainer-rs` and `testcontainers-modules`
parse-display = { version = "0.9.1", optional = true, default-features = false, features = [] }
testcontainers = { version = "0.22.0" }
testcontainers = { version = "0.23.0" }


[dev-dependencies]
Expand Down Expand Up @@ -92,7 +92,7 @@ serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
surrealdb = { version = "1.2.0" }
tar = "0.4.40"
testcontainers = { version = "0.22.0", features = ["blocking"] }
testcontainers = { version = "0.23.0", features = ["blocking"] }
# To use Tiberius on macOS, rustls is needed instead of native-tls
# https://github.com/prisma/tiberius/tree/v0.12.2#encryption-tlsssl
tiberius = { version = "0.12.2", default-features = false, features = [
Expand Down
27 changes: 27 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,30 @@ pub mod zookeeper;

/// Re-exported version of `testcontainers` to avoid version conflicts
pub use testcontainers;

#[cfg(any(feature = "postgres", feature = "mariadb", feature = "mysql"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "postgres", feature = "mariadb", feature = "mysql")))
)]
/// Trait to alight interface for users across different modules.
pub trait InitSql {
/// Registers sql to be executed automatically when the container starts.
///
/// # Example
///
/// ```
/// # use testcontainers_modules::postgres::Postgres;
/// # use testcontainers_modules::InitSql;
/// let postgres_image =
/// Postgres::default().with_init_sql("CREATE EXTENSION IF NOT EXISTS hstore;");
/// ```
///
/// ```rust,ignore
/// # use testcontainers_modules::postgres::Postgres;
/// # use testcontainers_modules::rdbms::InitSql;
/// let postgres_image = Postgres::default()
/// .with_init_sql(include_str!("path_to_init.sql"));
/// ```
fn with_init_sql(self, init_sql: impl ToString) -> Self;
}
40 changes: 37 additions & 3 deletions src/mariadb/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::borrow::Cow;

use testcontainers::{core::WaitFor, Image};
use testcontainers::{core::WaitFor, CopyToContainer, Image};

const NAME: &str = "mariadb";
const TAG: &str = "11.3";
Expand All @@ -27,7 +27,7 @@ const TAG: &str = "11.3";
/// [`MariaDB docker image`]: https://hub.docker.com/_/mariadb
#[derive(Debug, Default, Clone)]
pub struct Mariadb {
_priv: (),
init_sqls: Vec<CopyToContainer>,
}

impl Image for Mariadb {
Expand All @@ -54,8 +54,21 @@ impl Image for Mariadb {
("MARIADB_ALLOW_EMPTY_ROOT_PASSWORD", "1"),
]
}
fn copy_to_sources(&self) -> impl IntoIterator<Item = &CopyToContainer> {
&self.init_sqls
}
}
impl crate::InitSql for Mariadb {
fn with_init_sql(mut self, init_sql: impl ToString) -> Self {
let init_vec = init_sql.to_string().into_bytes();
let target = format!(
"/docker-entrypoint-initdb.d/init_{i}.sql",
i = self.init_sqls.len()
);
self.init_sqls.push(CopyToContainer::new(init_vec, target));
self
}
}

#[cfg(test)]
mod tests {
use mysql::prelude::Queryable;
Expand All @@ -66,6 +79,27 @@ mod tests {
testcontainers::{runners::SyncRunner, ImageExt},
};

#[test]
fn mariadb_with_init_sql() -> Result<(), Box<dyn std::error::Error + 'static>> {
use crate::InitSql;
let node = MariadbImage::default()
.with_init_sql("CREATE TABLE foo (bar varchar(255));")
.start()?;

let connection_string = &format!(
"mysql://root@{}:{}/test",
node.get_host()?,
node.get_host_port_ipv4(3306.tcp())?
);
let mut conn = mysql::Conn::new(mysql::Opts::from_url(connection_string).unwrap()).unwrap();

let rows = conn.query("INSERT INTO foo(bar) VALUES ('blub')").unwrap();
assert_eq!(rows.len(), 0);

let rows = conn.query("SELECT bar FROM foo").unwrap();
assert_eq!(rows.len(), 1);
Ok(())
}
#[test]
fn mariadb_one_plus_one() -> Result<(), Box<dyn std::error::Error + 'static>> {
let mariadb_image = MariadbImage::default();
Expand Down
41 changes: 39 additions & 2 deletions src/mysql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::borrow::Cow;

use testcontainers::{core::WaitFor, Image};
use testcontainers::{core::WaitFor, CopyToContainer, Image};

const NAME: &str = "mysql";
const TAG: &str = "8.1";
Expand All @@ -27,7 +27,7 @@ const TAG: &str = "8.1";
/// [`MySQL docker image`]: https://hub.docker.com/_/mysql
#[derive(Debug, Default, Clone)]
pub struct Mysql {
_priv: (),
init_sqls: Vec<CopyToContainer>,
}

impl Image for Mysql {
Expand All @@ -54,17 +54,54 @@ impl Image for Mysql {
("MYSQL_ALLOW_EMPTY_PASSWORD", "yes"),
]
}
fn copy_to_sources(&self) -> impl IntoIterator<Item = &CopyToContainer> {
&self.init_sqls
}
}
impl crate::InitSql for Mysql {
fn with_init_sql(mut self, init_sql: impl ToString) -> Self {
let init_vec = init_sql.to_string().into_bytes();
let target = format!(
"/docker-entrypoint-initdb.d/init_{i}.sql",
i = self.init_sqls.len()
);
self.init_sqls.push(CopyToContainer::new(init_vec, target));
self
}
}

#[cfg(test)]
mod tests {
use mysql::prelude::Queryable;
use testcontainers::core::IntoContainerPort;

use crate::{
mysql::Mysql as MysqlImage,
testcontainers::{runners::SyncRunner, ImageExt},
};

#[test]
fn mysql_with_init_sql() -> Result<(), Box<dyn std::error::Error + 'static>> {
use crate::InitSql;
let node = crate::mysql::Mysql::default()
.with_init_sql("CREATE TABLE foo (bar varchar(255));")
.start()?;

let connection_string = &format!(
"mysql://root@{}:{}/test",
node.get_host()?,
node.get_host_port_ipv4(3306.tcp())?
);
let mut conn = mysql::Conn::new(mysql::Opts::from_url(connection_string).unwrap()).unwrap();

let rows = conn.query("INSERT INTO foo(bar) VALUES ('blub')").unwrap();
assert_eq!(rows.len(), 0);

let rows = conn.query("SELECT bar FROM foo").unwrap();
assert_eq!(rows.len(), 1);
Ok(())
}

#[test]
fn mysql_one_plus_one() -> Result<(), Box<dyn std::error::Error + 'static>> {
let mysql_image = MysqlImage::default();
Expand Down
49 changes: 47 additions & 2 deletions src/postgres/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::{borrow::Cow, collections::HashMap};

use testcontainers::{core::WaitFor, Image};
use testcontainers::{
core::{Mount, WaitFor},
CopyToContainer, Image,
};

const NAME: &str = "postgres";
const TAG: &str = "11-alpine";
Expand Down Expand Up @@ -30,6 +33,7 @@ const TAG: &str = "11-alpine";
#[derive(Debug, Clone)]
pub struct Postgres {
env_vars: HashMap<String, String>,
init_sqls: Vec<CopyToContainer>,
}

impl Postgres {
Expand Down Expand Up @@ -62,6 +66,17 @@ impl Postgres {
self
}
}
impl crate::InitSql for Postgres {
fn with_init_sql(mut self, init_sql: impl ToString) -> Self {
let init_vec = init_sql.to_string().into_bytes();
let target = format!(
"/docker-entrypoint-initdb.d/init_{i}.sql",
i = self.init_sqls.len()
);
self.init_sqls.push(CopyToContainer::new(init_vec, target));
self
}
}

impl Default for Postgres {
fn default() -> Self {
Expand All @@ -70,7 +85,10 @@ impl Default for Postgres {
env_vars.insert("POSTGRES_USER".to_owned(), "postgres".to_owned());
env_vars.insert("POSTGRES_PASSWORD".to_owned(), "postgres".to_owned());

Self { env_vars }
Self {
env_vars,
init_sqls: Vec::new(),
}
}
}

Expand All @@ -95,6 +113,9 @@ impl Image for Postgres {
) -> impl IntoIterator<Item = (impl Into<Cow<'_, str>>, impl Into<Cow<'_, str>>)> {
&self.env_vars
}
fn copy_to_sources(&self) -> impl IntoIterator<Item = &CopyToContainer> {
&self.init_sqls
}
}

#[cfg(test)]
Expand Down Expand Up @@ -144,4 +165,28 @@ mod tests {
assert!(first_column.contains("13"));
Ok(())
}

#[test]
fn postgres_with_init_sql() -> Result<(), Box<dyn std::error::Error + 'static>> {
use crate::InitSql;
let node = Postgres::default()
.with_init_sql("CREATE TABLE foo (bar varchar(255));")
.start()?;

let connection_string = &format!(
"postgres://postgres:postgres@{}:{}/postgres",
node.get_host()?,
node.get_host_port_ipv4(5432)?
);
let mut conn = postgres::Client::connect(connection_string, postgres::NoTls).unwrap();

let rows = conn
.query("INSERT INTO foo(bar) VALUES ($1)", &[&"blub"])
.unwrap();
assert_eq!(rows.len(), 0);

let rows = conn.query("SELECT bar FROM foo", &[]).unwrap();
assert_eq!(rows.len(), 1);
Ok(())
}
}

0 comments on commit 17645a3

Please sign in to comment.