Skip to content

Commit

Permalink
Merge pull request #2981 from oesteban/tst/parametrize-join-expansion
Browse files Browse the repository at this point in the history
TST: Parametrize JoinNode expansion tests over config ``needed_outputs``
  • Loading branch information
oesteban authored Aug 1, 2019
2 parents 5965d45 + 2e9ecc1 commit 7262b24
Showing 1 changed file with 34 additions and 26 deletions.
60 changes: 34 additions & 26 deletions nipype/pipeline/engine/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from __future__ import (print_function, division, unicode_literals,
absolute_import)
from builtins import open
import pytest

from .... import config
from ... import engine as pe
from ....interfaces import base as nib
from ....interfaces.utility import IdentityInterface, Function, Merge
Expand Down Expand Up @@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
output1 = nib.traits.Int(desc='ouput')


class IncrementInterface(nib.BaseInterface):
class IncrementInterface(nib.SimpleInterface):
input_spec = IncrementInputSpec
output_spec = IncrementOutputSpec

def _run_interface(self, runtime):
runtime.returncode = 0
self._results['output1'] = self.inputs.input1 + self.inputs.inc
return runtime

def _list_outputs(self):
outputs = self._outputs().get()
outputs['output1'] = self.inputs.input1 + self.inputs.inc
return outputs


_sums = []

Expand All @@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
operands = nib.traits.List(nib.traits.Int, desc='operands')


class SumInterface(nib.BaseInterface):
class SumInterface(nib.SimpleInterface):
input_spec = SumInputSpec
output_spec = SumOutputSpec

def _run_interface(self, runtime):
runtime.returncode = 0
return runtime

def _list_outputs(self):
global _sum
global _sum_operands
outputs = self._outputs().get()
outputs['operands'] = self.inputs.input1
_sum_operands.append(outputs['operands'])
outputs['output1'] = sum(self.inputs.input1)
_sums.append(outputs['output1'])
return outputs
runtime.returncode = 0
self._results['operands'] = self.inputs.input1
self._results['output1'] = sum(self.inputs.input1)
_sum_operands.append(self.inputs.input1)
_sums.append(sum(self.inputs.input1))
return runtime


_set_len = None
Expand Down Expand Up @@ -148,35 +142,48 @@ def _list_outputs(self):
return outputs


def test_join_expansion(tmpdir):
@pytest.mark.parametrize('needed_outputs', ['true', 'false'])
def test_join_expansion(tmpdir, needed_outputs):
global _sums
global _sum_operands
global _products
tmpdir.chdir()

# Clean up, just in case some other test modified them
_products = []
_sum_operands = []
_sums = []

prev_state = config.get('execution', 'remove_unnecessary_outputs')
config.set('execution', 'remove_unnecessary_outputs', needed_outputs)
# Make the workflow.
wf = pe.Workflow(name='test')
# the iterated input node
inputspec = pe.Node(IdentityInterface(fields=['n']), name='inputspec')
inputspec.iterables = [('n', [1, 2])]
# a pre-join node in the iterated path
pre_join1 = pe.Node(IncrementInterface(), name='pre_join1')
wf.connect(inputspec, 'n', pre_join1, 'input1')
# another pre-join node in the iterated path
pre_join2 = pe.Node(IncrementInterface(), name='pre_join2')
wf.connect(pre_join1, 'output1', pre_join2, 'input1')
# the join node
join = pe.JoinNode(
SumInterface(),
joinsource='inputspec',
joinfield='input1',
name='join')
wf.connect(pre_join2, 'output1', join, 'input1')
# an uniterated post-join node
post_join1 = pe.Node(IncrementInterface(), name='post_join1')
wf.connect(join, 'output1', post_join1, 'input1')
# a post-join node in the iterated path
post_join2 = pe.Node(ProductInterface(), name='post_join2')
wf.connect(join, 'output1', post_join2, 'input1')
wf.connect(pre_join1, 'output1', post_join2, 'input2')

wf.connect([
(inputspec, pre_join1, [('n', 'input1')]),
(pre_join1, pre_join2, [('output1', 'input1')]),
(pre_join1, post_join2, [('output1', 'input2')]),
(pre_join2, join, [('output1', 'input1')]),
(join, post_join1, [('output1', 'input1')]),
(join, post_join2, [('output1', 'input1')]),
])
result = wf.run()

# the two expanded pre-join predecessor nodes feed into one join node
Expand All @@ -185,8 +192,8 @@ def test_join_expansion(tmpdir):
# the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
# node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
# Nipype factors away the IdentityInterface.
assert len(
result.nodes()) == 8, "The number of expanded nodes is incorrect."
assert len(result.nodes()) == 8, "The number of expanded nodes is incorrect."

# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
assert len(_sums) == 1, "The number of join outputs is incorrect"
assert _sums[
Expand All @@ -197,6 +204,7 @@ def test_join_expansion(tmpdir):
# there are two iterations of the post-join node in the iterable path
assert len(_products) == 2,\
"The number of iterated post-join outputs is incorrect"
config.set('execution', 'remove_unnecessary_outputs', prev_state)


def test_node_joinsource(tmpdir):
Expand Down

0 comments on commit 7262b24

Please sign in to comment.