diff --git a/NumPyNet/network.py b/NumPyNet/network.py index 9000d2e..ce3990a 100644 --- a/NumPyNet/network.py +++ b/NumPyNet/network.py @@ -653,10 +653,7 @@ def predict(self, X, truth=None, verbose=True): num_data = len(X) _truth = None - if num_data > 1: - batches = np.array_split(range(num_data), indices_or_sections=num_data // self.batch) - else: - batches = [np.array([0])] + batches = np.array_split(range(num_data), indices_or_sections=(num_data // self.batch) if (self.batch <= num_data) else 1) begin = now() start = begin