diff --git a/CHANGELOG.md b/CHANGELOG.md index 60ab1c9b3c..f08d7e93e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## v0.11.2: + +### Critical bug fixes +- Link types were not respected in `Node.get_inputs` for SqlAlchemy [[#1271]](https://github.com/aiidateam/aiida_core/pull/1271) + + ## v0.11.1: ### Improvements diff --git a/aiida/__init__.py b/aiida/__init__.py index faac2f6c2d..aa46f583c9 100644 --- a/aiida/__init__.py +++ b/aiida/__init__.py @@ -13,7 +13,7 @@ __copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved." __license__ = "MIT license, see LICENSE.txt file." -__version__ = "0.11.1" +__version__ = "0.11.2" __authors__ = "The AiiDA team." __paper__ = """G. Pizzi, A. Cepellotti, R. Sabatini, N. Marzari, and B. Kozinsky, "AiiDA: automated interactive infrastructure and database for computational science", Comp. Mat. Sci 111, 218-230 (2016); http://dx.doi.org/10.1016/j.commatsci.2015.09.013 - http://www.aiida.net.""" __paper_short__ = """G. Pizzi et al., Comp. Mat. Sci 111, 218 (2016).""" diff --git a/aiida/backends/tests/nodes.py b/aiida/backends/tests/nodes.py index 368343d1a0..82649d6b43 100644 --- a/aiida/backends/tests/nodes.py +++ b/aiida/backends/tests/nodes.py @@ -1969,3 +1969,58 @@ def test_check_single_calc_source(self): # more than one input to the same data object! with self.assertRaises(ValueError): d1.add_link_from(calc2, link_type=LinkType.CREATE) + + def test_node_get_inputs_outputs_link_type_stored(self): + """ + Test that the link_type parameter in get_inputs and get_outputs only + returns those nodes with the correct link type for stored nodes + """ + node_origin = Node().store() + node_caller = Node().store() + node_called = Node().store() + node_input = Node().store() + node_output = Node().store() + node_return = Node().store() + + # Input links of node_origin + node_origin.add_link_from(node_caller, label='caller', link_type=LinkType.CALL) + node_origin.add_link_from(node_input, label='input', link_type=LinkType.INPUT) + + # Output links of node_origin + node_called.add_link_from(node_origin, label='called', link_type=LinkType.CALL) + node_output.add_link_from(node_origin, label='output', link_type=LinkType.CREATE) + node_return.add_link_from(node_origin, label='return', link_type=LinkType.RETURN) + + # All inputs and outputs + self.assertEquals(len(node_origin.get_inputs()), 2) + self.assertEquals(len(node_origin.get_outputs()), 3) + + # Link specific inputs + self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.CALL)), 1) + self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.INPUT)), 1) + + # Link specific outputs + self.assertEquals(len(node_origin.get_outputs(link_type=LinkType.CALL)), 1) + self.assertEquals(len(node_origin.get_outputs(link_type=LinkType.CREATE)), 1) + self.assertEquals(len(node_origin.get_outputs(link_type=LinkType.RETURN)), 1) + + def test_node_get_inputs_link_type_unstored(self): + """ + Test that the link_type parameter in get_inputs only returns those nodes with + the correct link type for unstored nodes. We don't check this analogously for + get_outputs because there is not output links cache + """ + node_origin = Node() + node_caller = Node() + node_input = Node() + + # Input links of node_origin + node_origin.add_link_from(node_caller, label='caller', link_type=LinkType.CALL) + node_origin.add_link_from(node_input, label='input', link_type=LinkType.INPUT) + + # All inputs and outputs + self.assertEquals(len(node_origin.get_inputs()), 2) + + # Link specific inputs + self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.CALL)), 1) + self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.INPUT)), 1) diff --git a/aiida/orm/implementation/general/node.py b/aiida/orm/implementation/general/node.py index 4ef41816a8..58c6052d5e 100644 --- a/aiida/orm/implementation/general/node.py +++ b/aiida/orm/implementation/general/node.py @@ -647,28 +647,24 @@ def get_outputs_dict(self, link_type=None): return new_outputs - def get_inputs(self, - node_type=None, - also_labels=False, - only_in_db=False, - link_type=None): + def get_inputs(self, node_type=None, also_labels=False, only_in_db=False, link_type=None): """ Return a list of nodes that enter (directly) in this node :param node_type: If specified, should be a class, and it filters only elements of that specific type (or a subclass of 'type') :param also_labels: If False (default) only return a list of input nodes. - If True, return a list of tuples, where each tuple has the - following format: ('label', Node), with 'label' the link label, - and Node a Node instance or subclass + If True, return a list of tuples, where each tuple has the + following format: ('label', Node), with 'label' the link label, + and Node a Node instance or subclass :param only_in_db: Return only the inputs that are in the database, - ignoring those that are in the local cache. Otherwise, return - all links. + ignoring those that are in the local cache. Otherwise, return + all links. :param link_type: Only get inputs of this link type, if None then - returns all inputs of all link types. + returns all inputs of all link types. """ if link_type is not None and not isinstance(link_type, LinkType): - raise TypeError("link_type should be a LinkType object") + raise TypeError('link_type should be a LinkType object') inputs_list = self._get_db_input_links(link_type=link_type) @@ -678,19 +674,18 @@ def get_inputs(self, for label, v in self._inputlinks_cache.iteritems(): src = v[0] + input_link_type = v[1] if label in input_list_keys: - raise InternalError( - "There exist a link with the same name " - "'{}' both in the DB and in the internal " - "cache for node pk= {}!".format(label, self.pk)) - inputs_list.append((label, src)) + raise InternalError("There exist a link with the same name '{}' both in the DB " + "and in the internal cache for node pk= {}!".format(label, self.pk)) + + if link_type is None or input_link_type is link_type: + inputs_list.append((label, src)) if node_type is None: filtered_list = inputs_list else: - filtered_list = [ - i for i in inputs_list if isinstance(i[1], node_type) - ] + filtered_list = [i for i in inputs_list if isinstance(i[1], node_type)] if also_labels: return list(filtered_list) @@ -708,33 +703,33 @@ def _get_db_input_links(self, link_type): """ pass - # pylint: disable=no-else-return @override - def get_outputs(self, type=None, also_labels=False, link_type=None): + def get_outputs(self, node_type=None, also_labels=False, link_type=None): """ Return a list of nodes that exit (directly) from this node - :param type: if specified, should be a class, and it filters only - elements of that specific type (or a subclass of 'type') + :param node_type: if specified, should be a class, and it filters only + elements of that specific node_type (or a subclass of 'node_type') :param also_labels: if False (default) only return a list of input nodes. - If True, return a list of tuples, where each tuple has the - following format: ('label', Node), with 'label' the link label, - and Node a Node instance or subclass + If True, return a list of tuples, where each tuple has the + following format: ('label', Node), with 'label' the link label, + and Node a Node instance or subclass :param link_type: Only return outputs connected by links of this type. """ + if link_type is not None and not isinstance(link_type, LinkType): + raise TypeError('link_type should be a LinkType object') + outputs_list = self._get_db_output_links(link_type=link_type) - if type is None: - if also_labels: - return list(outputs_list) - else: - return [i[1] for i in outputs_list] + if node_type is None: + filtered_list = outputs_list else: - filtered_list = (i for i in outputs_list if isinstance(i[1], type)) - if also_labels: - return list(filtered_list) - else: - return [i[1] for i in filtered_list] + filtered_list = (i for i in outputs_list if isinstance(i[1], node_type)) + + if also_labels: + return list(filtered_list) + + return [i[1] for i in filtered_list] @abstractmethod def _get_db_output_links(self, link_type): diff --git a/aiida/orm/implementation/sqlalchemy/node.py b/aiida/orm/implementation/sqlalchemy/node.py index 18602f9104..ba850aea28 100644 --- a/aiida/orm/implementation/sqlalchemy/node.py +++ b/aiida/orm/implementation/sqlalchemy/node.py @@ -271,7 +271,7 @@ def _get_db_input_links(self, link_type): if link_type is not None: link_filter['type'] = link_type.value return [(i.label, i.input.get_aiida_class()) for i in - DbLink.query.filter_by(output=self.dbnode).distinct().all()] + DbLink.query.filter_by(**link_filter).distinct().all()] def _get_db_output_links(self, link_type): diff --git a/aiida/parsers/parser.py b/aiida/parsers/parser.py index cad6fd6854..aca7d9dbe5 100644 --- a/aiida/parsers/parser.py +++ b/aiida/parsers/parser.py @@ -114,20 +114,19 @@ def get_result_parameterdata_node(self): from aiida.orm.data.parameter import ParameterData from aiida.common.exceptions import NotExistent - out_parameters = self._calc.get_outputs(type=ParameterData, also_labels=True) - out_parameterdata = [i[1] for i in out_parameters - if i[0] == self.get_linkname_outparams()] + out_parameters = self._calc.get_outputs(node_type=ParameterData, also_labels=True) + out_parameter_data = [i[1] for i in out_parameters if i[0] == self.get_linkname_outparams()] - if not out_parameterdata: + if not out_parameter_data: raise NotExistent("No output .res ParameterData node found") - elif len(out_parameterdata) > 1: + elif len(out_parameter_data) > 1: from aiida.common.exceptions import UniquenessError raise UniquenessError("Output ParameterData should be found once, " "found it instead {} times" - .format(len(out_parameterdata))) + .format(len(out_parameter_data))) - return out_parameterdata[0] + return out_parameter_data[0] def get_result_keys(self): """ diff --git a/aiida/tools/dbexporters/tcod.py b/aiida/tools/dbexporters/tcod.py index c9069188eb..f8e4c4abbc 100644 --- a/aiida/tools/dbexporters/tcod.py +++ b/aiida/tools/dbexporters/tcod.py @@ -976,7 +976,7 @@ def export_cifnode(what, parameters=None, trajectory_index=None, raise ValueError("Supplied parameters are not an " "instance of ParameterData") elif calc is not None: - params = calc.get_outputs(type=ParameterData, link_type=LinkType.CREATE) + params = calc.get_outputs(node_type=ParameterData, link_type=LinkType.CREATE) if len(params) == 1: parameters = params[0] elif len(params) > 0: diff --git a/aiida/workflows/wf_XTiO3.py b/aiida/workflows/wf_XTiO3.py index c5703af52f..10ee311dd6 100644 --- a/aiida/workflows/wf_XTiO3.py +++ b/aiida/workflows/wf_XTiO3.py @@ -207,7 +207,7 @@ def final_step(self): optimal_alat = self.get_attribute("optimal_alat") opt_calc = self.get_step_calculations(self.optimize)[0] # .get_calculations()[0] - opt_e = opt_calc.get_outputs(type=ParameterData)[0].get_dict()['energy'] + opt_e = opt_calc.get_outputs(node_type=ParameterData)[0].get_dict()['energy'] self.append_to_report(x_material + "Ti03 optimal with a=" + str(optimal_alat) + ", e=" + str(opt_e)) diff --git a/docs/source/old_workflows/index.rst b/docs/source/old_workflows/index.rst index e91d35ee5f..0ad2ec0f3e 100644 --- a/docs/source/old_workflows/index.rst +++ b/docs/source/old_workflows/index.rst @@ -535,7 +535,7 @@ aside to the final optimal cell parameter value. optimal_alat = self.get_attribute("optimal_alat") opt_calc = self.get_step_calculations(self.optimize)[0] #.get_calculations()[0] - opt_e = opt_calc.get_outputs(type=ParameterData)[0].get_dict()['energy'] + opt_e = opt_calc.get_outputs(node_type=ParameterData)[0].get_dict()['energy'] self.append_to_report(x_material+"Ti03 optimal with a="+str(optimal_alat)+", e="+str(opt_e)) @@ -741,7 +741,7 @@ phonon vibrational frequncies for some XTiO3 materials, namely Ba, Sr and Pb. run_ph_calcs = self.get_step_calculations(self.run_ph) #.get_calculations() for c in run_ph_calcs: - dm = c.get_outputs(type=ParameterData)[0].get_dict()['dynamical_matrix_1'] + dm = c.get_outputs(node_type=ParameterData)[0].get_dict()['dynamical_matrix_1'] self.append_to_report("Point q: {0} Frequencies: {1}".format(dm['q_point'],dm['frequencies'])) self.next(self.exit)