diff --git a/nipype/pipeline/engine/tests/test_join.py b/nipype/pipeline/engine/tests/test_join.py index 77fc0f2fdf..8553b79dc9 100644 --- a/nipype/pipeline/engine/tests/test_join.py +++ b/nipype/pipeline/engine/tests/test_join.py @@ -6,7 +6,9 @@ from __future__ import (print_function, division, unicode_literals, absolute_import) from builtins import open +import pytest +from .... import config from ... import engine as pe from ....interfaces import base as nib from ....interfaces.utility import IdentityInterface, Function, Merge @@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec): output1 = nib.traits.Int(desc='ouput') -class IncrementInterface(nib.BaseInterface): +class IncrementInterface(nib.SimpleInterface): input_spec = IncrementInputSpec output_spec = IncrementOutputSpec def _run_interface(self, runtime): runtime.returncode = 0 + self._results['output1'] = self.inputs.input1 + self.inputs.inc return runtime - def _list_outputs(self): - outputs = self._outputs().get() - outputs['output1'] = self.inputs.input1 + self.inputs.inc - return outputs - _sums = [] @@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec): operands = nib.traits.List(nib.traits.Int, desc='operands') -class SumInterface(nib.BaseInterface): +class SumInterface(nib.SimpleInterface): input_spec = SumInputSpec output_spec = SumOutputSpec def _run_interface(self, runtime): - runtime.returncode = 0 - return runtime - - def _list_outputs(self): global _sum global _sum_operands - outputs = self._outputs().get() - outputs['operands'] = self.inputs.input1 - _sum_operands.append(outputs['operands']) - outputs['output1'] = sum(self.inputs.input1) - _sums.append(outputs['output1']) - return outputs + runtime.returncode = 0 + self._results['operands'] = self.inputs.input1 + self._results['output1'] = sum(self.inputs.input1) + _sum_operands.append(self.inputs.input1) + _sums.append(sum(self.inputs.input1)) + return runtime _set_len = None @@ -148,9 +142,20 @@ def _list_outputs(self): return outputs -def test_join_expansion(tmpdir): +@pytest.mark.parametrize('needed_outputs', ['true', 'false']) +def test_join_expansion(tmpdir, needed_outputs): + global _sums + global _sum_operands + global _products tmpdir.chdir() + # Clean up, just in case some other test modified them + _products = [] + _sum_operands = [] + _sums = [] + + prev_state = config.get('execution', 'remove_unnecessary_outputs') + config.set('execution', 'remove_unnecessary_outputs', needed_outputs) # Make the workflow. wf = pe.Workflow(name='test') # the iterated input node @@ -158,25 +163,27 @@ def test_join_expansion(tmpdir): inputspec.iterables = [('n', [1, 2])] # a pre-join node in the iterated path pre_join1 = pe.Node(IncrementInterface(), name='pre_join1') - wf.connect(inputspec, 'n', pre_join1, 'input1') # another pre-join node in the iterated path pre_join2 = pe.Node(IncrementInterface(), name='pre_join2') - wf.connect(pre_join1, 'output1', pre_join2, 'input1') # the join node join = pe.JoinNode( SumInterface(), joinsource='inputspec', joinfield='input1', name='join') - wf.connect(pre_join2, 'output1', join, 'input1') # an uniterated post-join node post_join1 = pe.Node(IncrementInterface(), name='post_join1') - wf.connect(join, 'output1', post_join1, 'input1') # a post-join node in the iterated path post_join2 = pe.Node(ProductInterface(), name='post_join2') - wf.connect(join, 'output1', post_join2, 'input1') - wf.connect(pre_join1, 'output1', post_join2, 'input2') + wf.connect([ + (inputspec, pre_join1, [('n', 'input1')]), + (pre_join1, pre_join2, [('output1', 'input1')]), + (pre_join1, post_join2, [('output1', 'input2')]), + (pre_join2, join, [('output1', 'input1')]), + (join, post_join1, [('output1', 'input1')]), + (join, post_join2, [('output1', 'input1')]), + ]) result = wf.run() # the two expanded pre-join predecessor nodes feed into one join node @@ -185,8 +192,8 @@ def test_join_expansion(tmpdir): # the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join # node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes. # Nipype factors away the IdentityInterface. - assert len( - result.nodes()) == 8, "The number of expanded nodes is incorrect." + assert len(result.nodes()) == 8, "The number of expanded nodes is incorrect." + # the join Sum result is (1 + 1 + 1) + (2 + 1 + 1) assert len(_sums) == 1, "The number of join outputs is incorrect" assert _sums[ @@ -197,6 +204,7 @@ def test_join_expansion(tmpdir): # there are two iterations of the post-join node in the iterable path assert len(_products) == 2,\ "The number of iterated post-join outputs is incorrect" + config.set('execution', 'remove_unnecessary_outputs', prev_state) def test_node_joinsource(tmpdir):