Skip to content

Commit

Permalink
Introduce a more general solution for issue #6 in the network predict…
Browse files Browse the repository at this point in the history
… method.
  • Loading branch information
Mat092 committed Feb 20, 2024
1 parent 916053f commit 777db4f
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions NumPyNet/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,7 @@ def predict(self, X, truth=None, verbose=True):
num_data = len(X)
_truth = None

if num_data > 1:
batches = np.array_split(range(num_data), indices_or_sections=num_data // self.batch)
else:
batches = [np.array([0])]
batches = np.array_split(range(num_data), indices_or_sections=(num_data // self.batch) if (self.batch <= num_data) else 1)

begin = now()
start = begin
Expand Down

0 comments on commit 777db4f

Please sign in to comment.