-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsocworkchain.py
242 lines (207 loc) · 9.2 KB
/
socworkchain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
from aiida.plugins import WorkflowFactory
from aiida.engine import ToContext, WorkChain, if_
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida import orm
from aiida.common import AttributeDict
# Sub-workchain classes resolved through the AiiDA plugin (entry-point) registry.
PwBandsWorkChain = WorkflowFactory("quantumespresso.pw.bands")
PdosWorkChain = WorkflowFactory("quantumespresso.pdos")


class SOCWorkChain(WorkChain):
    """WorkChain to compute band structure and/or PDOS of a crystal with spin-orbit coupling.

    The ``properties`` input selects what runs: ``"bands"`` launches a ``PwBandsWorkChain``
    and ``"pdos"`` launches a ``PdosWorkChain``. When both run, the PDOS step reuses the
    structure and the SCF remote folder produced by the bands step.
    """

    label = "soc"

    @classmethod
    def define(cls, spec):
        """Define the process specification."""
        # yapf: disable
        super().define(spec)
        spec.input('structure', valid_type=orm.StructureData,
                   help='The inputs structure.')
        spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
                   help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')
        spec.input('properties', valid_type=orm.List, default=lambda: orm.List(),
                   help='The properties to calculate, used to control the logic of SOCWorkChain.')
        spec.expose_inputs(PwBandsWorkChain, namespace='bands',
                           exclude=('clean_workdir', 'structure', 'relax'),
                           namespace_options={'required': False, 'populate_defaults': False,
                                              'help': 'Inputs for the `PwBandsWorkChain`.'})
        spec.expose_inputs(PdosWorkChain, namespace='pdos',
                           exclude=('clean_workdir', 'structure'),
                           namespace_options={'required': False, 'populate_defaults': False,
                                              'help': 'Inputs for the `PdosWorkChain`.'})
        spec.outline(
            cls.setup,
            if_(cls.should_run_bands)(
                cls.run_bands,
                cls.inspect_bands,
            ),
            if_(cls.should_run_pdos)(
                cls.run_pdos,
                cls.inspect_pdos,
            ),
            cls.results,
        )
        spec.expose_outputs(
            PwBandsWorkChain, namespace='bands',
            namespace_options={'required': False, 'help': 'Outputs of the `PwBandsWorkChain`.'},
        )
        spec.expose_outputs(
            PdosWorkChain, namespace='pdos',
            namespace_options={'required': False, 'help': 'Outputs of the `PdosWorkChain`.'},
        )
        spec.exit_code(401, 'ERROR_SUB_PROCESS_FAILED_BANDS', message='the PwBandsWorkChain sub process failed')
        spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_PDOS', message='the PdosWorkChain sub process failed')

    @classmethod
    def get_builder_from_protocol(
        cls,
        structure,
        pw_code,
        dos_code,
        projwfc_code,
        protocol,
        properties,
        clean_workdir,
        functional="PBE",
        overrides=None,
        **kwargs,
    ):
        """Return a builder prepopulated with inputs selected according to the protocol.

        :param structure: the ``StructureData`` to compute.
        :param pw_code: code used for all pw.x calculations.
        :param dos_code: dos.x code; the ``pdos`` namespace is only filled if this and
            ``projwfc_code`` are both not ``None``.
        :param projwfc_code: projwfc.x code.
        :param protocol: protocol name forwarded to the sub-workchain builders.
        :param properties: list of properties to compute (e.g. ``["bands", "pdos"]``).
        :param clean_workdir: whether remote folders are cleaned when this workchain terminates.
        :param functional: ``"PBE"`` selects the PBE fully-relativistic PseudoDojo family;
            any other value selects the PBEsol one.
        :param overrides: optional nested dict with ``"bands"`` / ``"pdos"`` sub-dicts of
            protocol overrides. It is not modified by this method.
        :param kwargs: forwarded to both sub-workchain ``get_builder_from_protocol`` calls.
        :return: a ``ProcessBuilder`` for this workchain.
        """

        def copy_dicts(value):
            """Copy every nested dict so the caller's ``overrides`` is never mutated.

            Non-dict leaves (e.g. AiiDA nodes such as k-points) are shared, not cloned.
            """
            if isinstance(value, dict):
                return {key: copy_dicts(item) for key, item in value.items()}
            return value

        overrides = copy_dicts(overrides or {})
        builder = cls.get_builder()

        # Use only fully-relativistic (FR) pseudopotentials.
        if functional == "PBE":
            family_fr_pseudo = orm.load_group("PseudoDojo/0.4/PBE/FR/stringent/upf")
        else:
            family_fr_pseudo = orm.load_group("PseudoDojo/0.4/PBEsol/FR/stringent/upf")
        # The pseudos depend only on the structure: look them up once, not per calculation.
        fr_pseudos = family_fr_pseudo.get_pseudos(structure=structure)

        # Switch on spin-orbit coupling for every pw.x calculation of each sub-workchain,
        # creating any override namespace the caller did not provide.
        soc_calculations = {"bands": ("scf", "bands"), "pdos": ("scf", "nscf")}
        for key, calcs in soc_calculations.items():
            namespace = overrides.setdefault(key, {})
            for calc in calcs:
                pw = namespace.setdefault(calc, {}).setdefault("pw", {})
                pw["pseudos"] = dict(fr_pseudos)
                system = pw.setdefault("parameters", {}).setdefault("SYSTEM", {})
                system["lspinorb"] = True
                system["noncolin"] = True
                # Merge into (rather than replace) any metadata the caller supplied.
                options = pw.setdefault("metadata", {}).setdefault("options", {})
                options["max_wallclock_seconds"] = 82800

        # Set the structure
        builder.structure = structure

        # Bands workchain settings
        bands_overrides = overrides.pop("bands", {})
        soc_bands = PwBandsWorkChain.get_builder_from_protocol(
            structure=structure,
            code=pw_code,
            protocol=protocol,
            overrides=bands_overrides,
            **kwargs,
        )
        # Pop the inputs that are excluded from the exposed inputs of the bands workchain.
        soc_bands.pop("clean_workdir", None)
        soc_bands.pop("structure", None)
        soc_bands.pop("relax", None)
        # Drop any protocol-set ``nspin``; presumably it conflicts with the noncolin setup.
        soc_bands.scf["pw"]["parameters"]["SYSTEM"].pop("nspin", None)
        soc_bands.bands["pw"]["parameters"]["SYSTEM"].pop("nspin", None)
        if structure.pbc != (True, True, True):
            # Not fully periodic: use the explicit ``bands_kpoints`` given in the overrides
            # instead of ``bands_kpoints_distance``.
            soc_bands.pop("bands_kpoints_distance", None)
            soc_bands.update(
                {"bands_kpoints": bands_overrides["bands"]["bands_kpoints"]}
            )
        builder.bands = soc_bands

        # Pdos workchain settings (only when both post-processing codes are available).
        if dos_code is not None and projwfc_code is not None:
            pdos_overrides = overrides.pop("pdos", {})
            soc_pdos = PdosWorkChain.get_builder_from_protocol(
                structure=structure,
                pw_code=pw_code,
                dos_code=dos_code,
                projwfc_code=projwfc_code,
                protocol=protocol,
                overrides=pdos_overrides,
                **kwargs,
            )
            soc_pdos.pop("clean_workdir", None)
            soc_pdos.pop("structure", None)
            soc_pdos.scf["pw"]["parameters"]["SYSTEM"].pop("nspin", None)
            soc_pdos.nscf["pw"]["parameters"]["SYSTEM"].pop("nspin", None)
            builder.pdos = soc_pdos

        # Set the properties
        builder.properties = orm.List(list=properties)
        # Set the clean_workdir
        builder.clean_workdir = orm.Bool(clean_workdir)
        return builder

    def setup(self):
        """Define the current structure and the properties to calculate."""
        self.ctx.current_structure = self.inputs.structure
        self.ctx.properties = self.inputs.properties
        # Flags read by the ``if_`` conditions of the outline.
        self.ctx.run_bands = "bands" in self.ctx.properties
        self.ctx.run_pdos = "pdos" in self.ctx.properties

    def should_run_bands(self):
        """Return whether a bands calculation should be run."""
        return self.ctx.run_bands

    def run_bands(self):
        """Run the bands calculation."""
        inputs = AttributeDict(self.exposed_inputs(PwBandsWorkChain, namespace="bands"))
        inputs.metadata.call_link_label = "bands"
        inputs.structure = self.ctx.current_structure
        running = self.submit(PwBandsWorkChain, **inputs)
        self.report(f"launching PwBandsWorkChain<{running.pk}>")
        return ToContext(workchain_bands=running)

    def inspect_bands(self):
        """Verify that the bands calculation finished successfully.

        On success, store the structure actually used by the bands run (its
        ``primitive_structure`` output if present, else its input structure) and the
        remote folder of its ``scf`` sub-workchain, so that the PDOS step can reuse them.
        """
        workchain = self.ctx.workchain_bands
        if not workchain.is_finished_ok:
            self.report(
                f"bands workchain failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_BANDS
        scf = (
            workchain.get_outgoing(orm.WorkChainNode, link_label_filter="scf")
            .one()
            .node
        )
        try:
            self.ctx.current_structure = workchain.outputs.primitive_structure
        except AttributeError:
            self.ctx.current_structure = workchain.inputs.structure
        self.ctx.scf_parent_folder = scf.outputs.remote_folder

    def should_run_pdos(self):
        """Return whether a pdos calculation should be run."""
        return self.ctx.run_pdos

    def run_pdos(self):
        """Run the pdos calculation, restarting from the bands SCF when one is available."""
        inputs = AttributeDict(self.exposed_inputs(PdosWorkChain, namespace="pdos"))
        inputs.metadata.call_link_label = "pdos"
        inputs.structure = self.ctx.current_structure
        # Plain dict here; ``prepare_process_inputs`` wraps it back into a ``Dict`` node.
        inputs.nscf.pw.parameters = inputs.nscf.pw.parameters.get_dict()
        if "scf_parent_folder" in self.ctx:
            # Skip the PDOS scf step and start the nscf from the bands SCF.
            inputs.pop("scf", None)
            inputs.nscf.pw.parent_folder = self.ctx.scf_parent_folder
        inputs = prepare_process_inputs(PdosWorkChain, inputs)
        running = self.submit(PdosWorkChain, **inputs)
        self.report(f"launching PdosWorkChain<{running.pk}>")
        return ToContext(workchain_pdos=running)

    def inspect_pdos(self):
        """Verify that the pdos calculation finished successfully."""
        workchain = self.ctx.workchain_pdos
        if not workchain.is_finished_ok:
            self.report(
                f"pdos workchain failed with exit status {workchain.exit_status}"
            )
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED_PDOS

    def results(self):
        """Attach the exposed outputs of every sub-workchain that was run."""
        if self.ctx.run_bands:
            self.out_many(
                self.exposed_outputs(
                    self.ctx.workchain_bands, PwBandsWorkChain, namespace="bands"
                )
            )
        if self.ctx.run_pdos:
            self.out_many(
                self.exposed_outputs(
                    self.ctx.workchain_pdos, PdosWorkChain, namespace="pdos"
                )
            )

    def on_terminated(self):
        """Clean the remote folders of all called calculations if ``clean_workdir`` is True.

        The sub-workchains never receive ``clean_workdir`` (it is excluded from their
        exposed inputs), so cleaning is done here once everything has finished. This
        matters because the PDOS step reuses the bands SCF remote folder.
        """
        super().on_terminated()

        if self.inputs.clean_workdir.value is False:
            self.report("remote folders will not be cleaned")
            return

        cleaned_calcs = []
        for called_descendant in self.node.called_descendants:
            if isinstance(called_descendant, orm.CalcJobNode):
                try:
                    called_descendant.outputs.remote_folder._clean()
                    cleaned_calcs.append(called_descendant.pk)
                except (IOError, OSError, KeyError, AttributeError):
                    # Best effort: a calculation may have no remote folder or it may be
                    # unreachable; cleaning must never fail the terminated workchain.
                    pass

        if cleaned_calcs:
            self.report(
                f"cleaned remote folders of calculations: {' '.join(map(str, cleaned_calcs))}"
            )