diff --git a/src/AlterableResultIterator.php b/src/AlterableResultIterator.php index 2ddef8a9..a07b705f 100644 --- a/src/AlterableResultIterator.php +++ b/src/AlterableResultIterator.php @@ -222,6 +222,9 @@ public function take($offset, $limit) */ public function count() { + if ($this->resultIterator instanceof \Countable && $this->alterations->count() === 0) { + return $this->resultIterator->count(); + } return count($this->toArray()); } diff --git a/src/InnerResultIterator.php b/src/InnerResultIterator.php index d280a75e..3abcf01c 100644 --- a/src/InnerResultIterator.php +++ b/src/InnerResultIterator.php @@ -102,6 +102,7 @@ protected function executeQuery(): void $this->fetchStarted = true; } + private $count = null; /** * Counts found records (this is the number of records fetched, taking into account the LIMIT and OFFSET settings). * @@ -113,30 +114,14 @@ public function count() return $this->count; } - if ($this->tdbmService->getConnection()->getDatabasePlatform() instanceof MySqlPlatform) { + if ($this->fetchStarted && $this->tdbmService->getConnection()->getDatabasePlatform() instanceof MySqlPlatform) { // Optimisation: we don't need a separate "count" SQL request in MySQL. - return $this->getRowCountViaRowCountFunction(); - } else { - return $this->getRowCountViaSqlQuery(); - } - } - - private $count = null; - - /** - * Get the row count from the rowCount function (only works with MySQL) - */ - private function getRowCountViaRowCountFunction(): int - { - if (!$this->fetchStarted) { - $this->executeQuery(); + $this->count = $this->statement->rowCount(); + return $this->count; } - - $this->count = $this->statement->rowCount(); - return $this->count; + return $this->getRowCountViaSqlQuery(); } - /** * Makes a separate SQL query to compute the row count. * (not needed in MySQL) @@ -147,7 +132,7 @@ private function getRowCountViaSqlQuery(): int $this->logger->debug('Running count SQL request: '.$countSql); - $this->count = $this->tdbmService->getConnection()->fetchColumn($countSql, $this->parameters); + $this->count = (int) $this->tdbmService->getConnection()->fetchColumn($countSql, $this->parameters); return $this->count; } diff --git a/tests/TDBMDaoGeneratorTest.php b/tests/TDBMDaoGeneratorTest.php index 507dbf18..11f9df99 100644 --- a/tests/TDBMDaoGeneratorTest.php +++ b/tests/TDBMDaoGeneratorTest.php @@ -961,6 +961,17 @@ public function testPageJsonEncode(): void $this->assertCount(1, $msgDecoded); } + /** + * @depends testDaoGeneration + */ + public function testInnerResultIteratorCountAfterFetch(): void + { + $userDao = new TestUserDao($this->tdbmService); + $users = $userDao->getUsersByLoginStartingWith('j')->take(0, 4); + $users->toArray(); // We force to fetch + $this->assertEquals(3, $users->count()); + } + /** * @depends testDaoGeneration */