diff --git a/lib/Doctrine/DBAL/Portability/Statement.php b/lib/Doctrine/DBAL/Portability/Statement.php index bd9a445fe59..48d8e4ac96c 100644 --- a/lib/Doctrine/DBAL/Portability/Statement.php +++ b/lib/Doctrine/DBAL/Portability/Statement.php @@ -19,6 +19,7 @@ namespace Doctrine\DBAL\Portability; +use Doctrine\DBAL\Driver\StatementIterator; use Doctrine\DBAL\FetchMode; use Doctrine\DBAL\ParameterType; use function array_change_key_case; @@ -139,9 +140,7 @@ public function setFetchMode($fetchMode, $arg1 = null, $arg2 = null) */ public function getIterator() { - $data = $this->fetchAll(); - - return new \ArrayIterator($data); + return new StatementIterator($this); } /** diff --git a/tests/Doctrine/Tests/DBAL/Driver/StatementIteratorTest.php b/tests/Doctrine/Tests/DBAL/Driver/StatementIteratorTest.php index 82b2790fc19..95f107cf921 100644 --- a/tests/Doctrine/Tests/DBAL/Driver/StatementIteratorTest.php +++ b/tests/Doctrine/Tests/DBAL/Driver/StatementIteratorTest.php @@ -2,39 +2,101 @@ namespace Doctrine\Tests\DBAL\Driver; +use Doctrine\DBAL\Driver\IBMDB2\DB2Statement; +use Doctrine\DBAL\Driver\Mysqli\MysqliStatement; +use Doctrine\DBAL\Driver\OCI8\OCI8Statement; +use Doctrine\DBAL\Driver\SQLAnywhere\SQLAnywhereStatement; +use Doctrine\DBAL\Driver\SQLSrv\SQLSrvStatement; use Doctrine\DBAL\Driver\Statement; use Doctrine\DBAL\Driver\StatementIterator; +use Doctrine\DBAL\Portability\Statement as PortabilityStatement; +use IteratorAggregate; +use PHPUnit\Framework\MockObject\MockObject; +use Traversable; +use function extension_loaded; class StatementIteratorTest extends \Doctrine\Tests\DbalTestCase { - public function testGettingIteratorDoesNotCallFetch() + /** + * @dataProvider statementProvider() + */ + public function testGettingIteratorDoesNotCallFetch(string $class) : void { - $stmt = $this->createMock(Statement::class); + /** @var IteratorAggregate|MockObject $stmt */ + $stmt = $this->createPartialMock($class, ['fetch', 'fetchAll', 'fetchColumn']); $stmt->expects($this->never())->method('fetch'); $stmt->expects($this->never())->method('fetchAll'); $stmt->expects($this->never())->method('fetchColumn'); + $stmt->getIterator(); + } + + public function testIteratorIterationCallsFetchOncePerStep() : void + { + $stmt = $this->createMock(Statement::class); + + $calls = 0; + $this->configureStatement($stmt, $calls); + $stmtIterator = new StatementIterator($stmt); - $stmtIterator->getIterator(); + + $this->assertIterationCallsFetchOncePerStep($stmtIterator, $calls); } - public function testIterationCallsFetchOncePerStep() + /** + * @dataProvider statementProvider() + */ + public function testStatementIterationCallsFetchOncePerStep(string $class) : void + { + $stmt = $this->createPartialMock($class, ['fetch']); + + $calls = 0; + $this->configureStatement($stmt, $calls); + $this->assertIterationCallsFetchOncePerStep($stmt, $calls); + } + + private function configureStatement(MockObject $stmt, int &$calls) : void { $values = ['foo', '', 'bar', '0', 'baz', 0, 'qux', null, 'quz', false, 'impossible']; $calls = 0; - $stmt = $this->createMock(Statement::class); $stmt->expects($this->exactly(10)) ->method('fetch') ->willReturnCallback(function() use ($values, &$calls) { $value = $values[$calls]; $calls++; + return $value; }); + } - $stmtIterator = new StatementIterator($stmt); - foreach ($stmtIterator as $i => $_) { + private function assertIterationCallsFetchOncePerStep(Traversable $iterator, int &$calls) : void + { + foreach ($iterator as $i => $_) { $this->assertEquals($i + 1, $calls); } } + + /** + * @return string[][] + */ + public static function statementProvider() : iterable + { + if (extension_loaded('ibm_db2')) { + yield [DB2Statement::class]; + } + + yield [MysqliStatement::class]; + + if (extension_loaded('oci8')) { + yield [OCI8Statement::class]; + } + + yield [PortabilityStatement::class]; + yield [SQLAnywhereStatement::class]; + + if (extension_loaded('sqlsrv')) { + yield [SQLSrvStatement::class]; + } + } } diff --git a/tests/Doctrine/Tests/DBAL/Portability/StatementTest.php b/tests/Doctrine/Tests/DBAL/Portability/StatementTest.php index 2f35f0b64be..83775e13f2e 100644 --- a/tests/Doctrine/Tests/DBAL/Portability/StatementTest.php +++ b/tests/Doctrine/Tests/DBAL/Portability/StatementTest.php @@ -6,6 +6,7 @@ use Doctrine\DBAL\ParameterType; use Doctrine\DBAL\Portability\Connection; use Doctrine\DBAL\Portability\Statement; +use function iterator_to_array; class StatementTest extends \Doctrine\Tests\DbalTestCase { @@ -141,16 +142,11 @@ public function testSetFetchMode() public function testGetIterator() { - $data = array( - 'foo' => 'bar', - 'bar' => 'foo' - ); - - $this->wrappedStmt->expects($this->once()) - ->method('fetchAll') - ->will($this->returnValue($data)); + $this->wrappedStmt->expects($this->exactly(3)) + ->method('fetch') + ->willReturnOnConsecutiveCalls('foo', 'bar', false); - self::assertEquals(new \ArrayIterator($data), $this->stmt->getIterator()); + self::assertSame(['foo', 'bar'], iterator_to_array($this->stmt->getIterator())); } public function testRowCount()