Skip to content

Commit

Permalink
Merge pull request #121 from lewisbelcher/fix-torch-transform-bug
Browse files Browse the repository at this point in the history
Update torch_transform and tests
  • Loading branch information
mdbloice authored Jun 18, 2018
2 parents ae3e79e + 325b53b commit c3c9685
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
8 changes: 3 additions & 5 deletions Augmentor/Pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,11 +627,9 @@ def torch_transform(self):
"""
def _transform(image):
for operation in self.operations:
r = round(random.uniform(0, 1), 1)
if r <= operation.probability:
image = [image]
image = operation.perform_operation(image)

r = random.uniform(0, 1)
if r < operation.probability:
image = operation.perform_operation([image])[0]
return image

return _transform
Expand Down
16 changes: 12 additions & 4 deletions tests/test_torch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,21 @@ def test_torch_transform():
red[..., 0] = 255
red = Image.fromarray(red)

g = Augmentor.Operations.Greyscale(1)

p = Augmentor.Pipeline()
p.greyscale(1)

# include multiple transforms to test integration
p.greyscale(probability=1)
p.zoom(probability=1, min_factor=1.0, max_factor=1.0)
p.rotate_random_90(probability=1)

transforms = torchvision.transforms.Compose([
p.torch_transform()
])

assert red != transforms(red)
assert g.perform_operation([red]) == transforms(red)

# assert that all operations were correctly applied
result = red
for op in p.operations:
result = op.perform_operation([result])[0]
assert transforms(red) == result

0 comments on commit c3c9685

Please sign in to comment.