diff --git a/relate/src/main/scala/com/lucidchart/relate/CollectionsParser.scala b/relate/src/main/scala/com/lucidchart/relate/CollectionsParser.scala deleted file mode 100644 index c01a393..0000000 --- a/relate/src/main/scala/com/lucidchart/relate/CollectionsParser.scala +++ /dev/null @@ -1,42 +0,0 @@ -package com.lucidchart.relate - -import scala.collection.compat._ -import scala.language.higherKinds - -trait CollectionsParser { - def limitedCollection[B: RowParser, Col[_]](maxRows: Long)(implicit factory: Factory[B, Col[B]]) = - RowParser { result => - val builder = factory.newBuilder - - result.withResultSet { resultSet => - while (resultSet.getRow < maxRows && resultSet.next()) { - builder += implicitly[RowParser[B]].parse(result) - } - } - - builder.result - } - - implicit def option[B: RowParser] = RowParser[Option[B]] { result => - limitedCollection[B, List](1).parse(result).headOption - } - - implicit def collection[B: RowParser, Col[_]](implicit factory: Factory[B, Col[B]]) = - limitedCollection[B, Col](Long.MaxValue) - - implicit def pairCollection[Key: RowParser, Value: RowParser, PairCol[_, _]](implicit - factory: Factory[(Key, Value), PairCol[Key, Value]] - ) = - RowParser { result => - - val builder = factory.newBuilder - - result.withResultSet { resultSet => - while (resultSet.getRow < Long.MaxValue && resultSet.next()) { - builder += implicitly[RowParser[Key]].parse(result) -> implicitly[RowParser[Value]].parse(result) - } - } - - builder.result - } -} diff --git a/relate/src/main/scala/com/lucidchart/relate/CollectionsSqlResult.scala b/relate/src/main/scala/com/lucidchart/relate/CollectionsSqlResult.scala index eac76bc..d3e1c46 100644 --- a/relate/src/main/scala/com/lucidchart/relate/CollectionsSqlResult.scala +++ b/relate/src/main/scala/com/lucidchart/relate/CollectionsSqlResult.scala @@ -18,7 +18,7 @@ trait CollectionsSqlResult { self: SqlResult => withResultSet { resultSet => while (resultSet.getRow < maxRows && resultSet.next()) { - builder += parser(asRow) + builder += parser(SqlRow(resultSet)) } } @@ -41,7 +41,7 @@ trait CollectionsSqlResult { self: SqlResult => withResultSet { resultSet => while (resultSet.getRow < maxRows && resultSet.next()) { - builder += parser(asRow) + builder += parser(SqlRow(resultSet)) } } diff --git a/relate/src/main/scala/com/lucidchart/relate/ResultSetWrapper.scala b/relate/src/main/scala/com/lucidchart/relate/ResultSetWrapper.scala deleted file mode 100644 index 0b05981..0000000 --- a/relate/src/main/scala/com/lucidchart/relate/ResultSetWrapper.scala +++ /dev/null @@ -1,43 +0,0 @@ -package com.lucidchart.relate - -import java.sql.SQLException - -trait ResultSetWrapper { - val resultSet: java.sql.ResultSet - - /** - * Determine if the last value extracted from the result set was null - * @return - * whether the last value was null - */ - def wasNull(): Boolean = resultSet.wasNull() - - def next(): Boolean = resultSet.next() - - def withResultSet[A](f: (java.sql.ResultSet) => A) = { - try { - f(resultSet) - } finally { - resultSet.close() - } - } - - /** - * Determine if the result set contains the given column name - * @param column - * the column name to check - * @return - * whether or not the result set contains that column name - */ - def hasColumn(column: String): Boolean = { - try { - resultSet.findColumn(column) - true - } catch { - case e: SQLException => false - } - } - - private[relate] def asRow: SqlRow = SqlRow(resultSet) - private[relate] def asResult: SqlResult = SqlResult(resultSet) -} diff --git a/relate/src/main/scala/com/lucidchart/relate/RowIterator.scala b/relate/src/main/scala/com/lucidchart/relate/RowIterator.scala index 1c1d952..875a8ed 100644 --- a/relate/src/main/scala/com/lucidchart/relate/RowIterator.scala +++ b/relate/src/main/scala/com/lucidchart/relate/RowIterator.scala @@ -10,7 +10,7 @@ private[relate] object RowIterator { private[relate] class RowIterator[A](parser: SqlRow => A, stmt: PreparedStatement, result: SqlResult) extends Iterator[A] { - private var _hasNext = result.next() + private var _hasNext = result.resultSet.next() /** * Make certain that all resources are closed @@ -32,9 +32,9 @@ private[relate] class RowIterator[A](parser: SqlRow => A, stmt: PreparedStatemen * the parsed record */ override def next(): A = { - val ret = parser(result.asRow) + val ret = parser(SqlRow(result.resultSet)) if (_hasNext) { - _hasNext = result.next() + _hasNext = result.resultSet.next() } // if we've iterated through the whole thing, close resources diff --git a/relate/src/main/scala/com/lucidchart/relate/RowParser.scala b/relate/src/main/scala/com/lucidchart/relate/RowParser.scala index 8306118..b45a549 100644 --- a/relate/src/main/scala/com/lucidchart/relate/RowParser.scala +++ b/relate/src/main/scala/com/lucidchart/relate/RowParser.scala @@ -11,7 +11,7 @@ trait RowParser[A] extends (SqlRow => A) { def apply(row: SqlRow) = parse(row) } -object RowParser extends CollectionsParser { +object RowParser { def apply[A](f: (SqlRow) => A) = new RowParser[A] { def parse(row: SqlRow) = f(row) } @@ -25,24 +25,4 @@ object RowParser extends CollectionsParser { private[relate] val insertInt = (row: SqlRow) => row.strictInt(1) private[relate] val insertLong = (row: SqlRow) => row.strictLong(1) - - implicit def multiMap[Key: RowParser, Value: RowParser] = RowParser[Map[Key, Set[Value]]] { result => - val mm: mutable.Map[Key, Set[Value]] = new mutable.HashMap[Key, Set[Value]] - - result.withResultSet { resultSet => - while (resultSet.next()) { - val key = implicitly[RowParser[Key]].parse(result) - val value = implicitly[RowParser[Value]].parse(result) - - mm.get(key) - .map { foundValue => - mm += (key -> (foundValue + value)) - } - .getOrElse { - mm += (key -> Set(value)) - } - } - } - mm.toMap - } } diff --git a/relate/src/main/scala/com/lucidchart/relate/SqlQuery.scala b/relate/src/main/scala/com/lucidchart/relate/SqlQuery.scala index 5e744c0..b8383f6 100644 --- a/relate/src/main/scala/com/lucidchart/relate/SqlQuery.scala +++ b/relate/src/main/scala/com/lucidchart/relate/SqlQuery.scala @@ -153,8 +153,6 @@ trait Sql extends CollectionsSql { def executeInsertSingle[U](parser: SqlRow => U)(implicit connection: Connection): U = insertionStatement.execute(_.asSingle(parser)) - def as[A: RowParser]()(implicit connection: Connection): A = normalStatement.execute(_.as[A]) - /** * Execute this query and get back the result as a single record * @param parser diff --git a/relate/src/main/scala/com/lucidchart/relate/SqlResult.scala b/relate/src/main/scala/com/lucidchart/relate/SqlResult.scala index 9b1014e..85d80dc 100644 --- a/relate/src/main/scala/com/lucidchart/relate/SqlResult.scala +++ b/relate/src/main/scala/com/lucidchart/relate/SqlResult.scala @@ -3,6 +3,7 @@ package com.lucidchart.relate import java.sql.ResultSetMetaData import scala.collection.mutable import scala.language.higherKinds +import scala.util.Using object SqlResult { def apply(resultSet: java.sql.ResultSet) = new SqlResult(resultSet) @@ -19,9 +20,8 @@ object SqlResult { * The extraction methods (int, string, long, etc.) also have "strict" counterparts. The "strict" methods are slightly * faster, but do not do type checking or handle null values. */ -class SqlResult(val resultSet: java.sql.ResultSet) extends ResultSetWrapper with CollectionsSqlResult { - - def as[A: RowParser](): A = implicitly[RowParser[A]].parse(asRow) +class SqlResult(private[relate] val resultSet: java.sql.ResultSet) extends CollectionsSqlResult { + private[relate] def withResultSet[A](f: (java.sql.ResultSet) => A) = Using.resource(resultSet)(f) def asSingle[A: RowParser](): A = asCollection[A, Seq](1).head def asSingle[A](parser: SqlRow => A): A = asCollection[A, Seq](parser, 1).head @@ -42,7 +42,7 @@ class SqlResult(val resultSet: java.sql.ResultSet) extends ResultSetWrapper with val mm: mutable.MultiMap[U, V] = new mutable.HashMap[U, mutable.Set[V]] with mutable.MultiMap[U, V] withResultSet { resultSet => while (resultSet.next()) { - val parsed = parser(asRow) + val parsed = parser(SqlRow(resultSet)) mm.addBinding(parsed._1, parsed._2) } } @@ -50,12 +50,8 @@ class SqlResult(val resultSet: java.sql.ResultSet) extends ResultSetWrapper with } def asScalar[A](): A = asScalarOption.get - def asScalarOption[A](): Option[A] = { - if (resultSet.next()) { - Some(resultSet.getObject(1).asInstanceOf[A]) - } else { - None - } + def asScalarOption[A](): Option[A] = withResultSet { resultSet => + Option.when(resultSet.next())(resultSet.getObject(1).asInstanceOf[A]) } /** diff --git a/relate/src/main/scala/com/lucidchart/relate/SqlRow.scala b/relate/src/main/scala/com/lucidchart/relate/SqlRow.scala index 1cf7bf4..b6f5d2e 100644 --- a/relate/src/main/scala/com/lucidchart/relate/SqlRow.scala +++ b/relate/src/main/scala/com/lucidchart/relate/SqlRow.scala @@ -13,7 +13,7 @@ object SqlRow { def apply(rs: java.sql.ResultSet): SqlRow = new SqlRow(rs) } -class SqlRow(val resultSet: java.sql.ResultSet) extends ResultSetWrapper { +class SqlRow(resultSet: java.sql.ResultSet) { /** * Get the number of the row the SqlResult is currently on diff --git a/relate/src/test/scala/ImplicitParsingTest.scala b/relate/src/test/scala/ImplicitParsingTest.scala deleted file mode 100644 index c9f9a23..0000000 --- a/relate/src/test/scala/ImplicitParsingTest.scala +++ /dev/null @@ -1,142 +0,0 @@ -package com.lucidchart.relate - -import java.sql.Connection -import org.specs2.mock.Mockito -import org.specs2.mutable._ - -class ImplicitParsingTest extends Specification with Mockito { - def getMocks = { - val rs = mock[java.sql.ResultSet] - (rs, SqlResult(rs)) - } - - implicit val con: Connection = null - - case class TestRecord(name: String) - - object TestRecord { - implicit val praser = new RowParser[TestRecord] { - def parse(result: SqlRow): TestRecord = { - TestRecord(result.string("name")) - } - } - } - - case class TestKey(key: String) - - object TestKey { - implicit val parse = new RowParser[TestKey] { - def parse(result: SqlRow): TestKey = { - TestKey(result.string("key")) - } - } - } - - "RowParser" should { - "build a list" in { - val (rs, result) = getMocks - - rs.getRow returns 0 thenReturns 1 thenReturns 2 - rs.next returns true thenReturns true thenReturns false - rs.getString("name") returns "hello" thenReturns "world" - - result.as[List[TestRecord]] mustEqual List( - TestRecord("hello"), - TestRecord("world") - ) - - success - } - - "build a seq" in { - val (rs, result) = getMocks - - rs.getRow returns 0 thenReturns 1 thenReturns 2 - rs.next returns true thenReturns true thenReturns false - rs.getString("name") returns "hello" thenReturns "world" - - result.as[Seq[TestRecord]] mustEqual Seq( - TestRecord("hello"), - TestRecord("world") - ) - - success - } - - "build an iterable" in { - val (rs, result) = getMocks - - rs.getRow returns 0 thenReturns 1 thenReturns 2 - rs.next returns true thenReturns true thenReturns false - rs.getString("name") returns "hello" thenReturns "world" - - result.as[Iterable[TestRecord]] mustEqual Iterable( - TestRecord("hello"), - TestRecord("world") - ) - - success - } - - "build an iterable" in { - val (rs, result) = getMocks - - rs.getRow returns 0 thenReturns 1 thenReturns 2 - rs.next returns true thenReturns true thenReturns false - rs.getString("name") returns "hello" thenReturns "world" - - result.as[Iterable[TestRecord]] mustEqual Iterable( - TestRecord("hello"), - TestRecord("world") - ) - - success - } - - "build a map" in { - val (rs, result) = getMocks - - rs.getRow returns 0 thenReturns 1 thenReturns 2 - rs.next returns true thenReturns true thenReturns false - rs.getString("name") returns "hello" thenReturns "world" - rs.getString("key") returns "1" thenReturns "2" - - result.as[Map[TestKey, TestRecord]] mustEqual Map( - TestKey("1") -> TestRecord("hello"), - TestKey("2") -> TestRecord("world") - ) - } - - "build a multi-map" in { - val (rs, result) = getMocks - - rs.getRow returns 0 thenReturns 1 thenReturns 2 thenReturns 3 - rs.next returns true thenReturns true thenReturns true thenReturns false - rs.getString("name") returns "hello" thenReturns "world" thenReturns "relate" - rs.getString("key") returns "1" thenReturns "2" thenReturns "1" - - result.as[Map[TestKey, Set[TestRecord]]] mustEqual Map( - TestKey("1") -> Set(TestRecord("hello"), TestRecord("relate")), - TestKey("2") -> Set(TestRecord("world")) - ) - } - - "build an option of something" in { - val (rs, result) = getMocks - - rs.getRow returns 0 thenReturns 1 thenReturns 2 thenReturns 3 - rs.next returns true thenReturns true thenReturns true thenReturns false - rs.getString("name") returns "hello" thenReturns "world" thenReturns "relate" - - result.as[Option[TestRecord]] mustEqual Some(TestRecord("hello")) - } - - "build a None of something" in { - val (rs, result) = getMocks - - rs.next returns false - - result.as[Option[TestRecord]] mustEqual None - } - } -} diff --git a/relate/src/test/scala/SqlResultSpec.scala b/relate/src/test/scala/SqlResultSpec.scala index 7b7f9ce..da07ab6 100644 --- a/relate/src/test/scala/SqlResultSpec.scala +++ b/relate/src/test/scala/SqlResultSpec.scala @@ -256,6 +256,27 @@ class SqlResultSpec extends Specification with Mockito { result.asScalarOption[Long] must_== None } + + "close the ResultSet" in { + val (rs, _, result) = getMocks + + rs.next returns true + rs.getObject(1) returns (2: java.lang.Long) + + result.asScalar[Long] mustEqual 2L + + there was one(rs).close() + } + + "close the ResultSet if it was empty" in { + val (rs, _, result) = getMocks + + rs.next returns false + + result.asScalarOption[Long] must beNone + + there was one(rs).close() + } } "extractOption" should { @@ -307,15 +328,6 @@ class SqlResultSpec extends Specification with Mockito { } } - "wasNull" should { - "return true if the last read was null" in { - val (rs, _, result) = getMocks - - rs.wasNull() returns true - result.wasNull mustEqual true - } - } - "strictArray" should { "properly pass through the call to ResultSet" in { val (rs, row, _) = getMocks