Source code for aiida_hubbard.workflows.hp.parallelize_atoms

# -*- coding: utf-8 -*-
"""Work chain to launch a Quantum Espresso hp.x calculation parallelizing over the Hubbard atoms."""
from aiida import orm
from aiida.common import AttributeDict
from aiida.engine import WorkChain, while_
from aiida.plugins import CalculationFactory, WorkflowFactory

from aiida_hubbard.utils.general import distribute_base_workchains

# Plugin classes are resolved by entry-point name through the AiiDA factories
# instead of being imported directly from their defining packages.
# NOTE(review): `PwCalculation` and `HpCalculation` are not referenced in the
# visible part of this module; they are kept as module-level names in case
# other modules import them from here.
PwCalculation = CalculationFactory('quantumespresso.pw')
HpCalculation = CalculationFactory('quantumespresso.hp')
HpBaseWorkChain = WorkflowFactory('quantumespresso.hp.base')
HpParallelizeQpointsWorkChain = WorkflowFactory('quantumespresso.hp.parallelize_qpoints')
class HpParallelizeAtomsWorkChain(WorkChain):
    """Work chain to launch a Quantum Espresso hp.x calculation parallelizing over the Hubbard atoms.

    The outline runs three stages:

    1. an initialization `HpBaseWorkChain` (with `only_initialization=True`) whose output
       parameters provide the `hubbard_sites` to perturb;
    2. one child work chain per Hubbard site, each restricted to a single atom through the
       `perturb_only_atom(<index>)` flag of the `INPUTHP` namelist;
    3. a final `HpBaseWorkChain` that receives the `retrieved` outputs of all per-atom
       children through `hp.parent_hp` and whose outputs are exposed as the outputs of
       this work chain.
    """

    @classmethod
    def define(cls, spec):
        """Define the process specification.

        :param spec: the process specification to populate with inputs, outline,
            outputs and exit codes.
        """
        # yapf: disable
        super().define(spec)

        # `only_initialization` and `clean_workdir` are excluded because this work chain
        # sets them itself on the children it launches.
        spec.expose_inputs(HpBaseWorkChain, exclude=('only_initialization', 'clean_workdir'))
        # When true, each atom is run through `HpParallelizeQpointsWorkChain` instead of
        # a plain `HpBaseWorkChain` (see `run_atoms`).
        spec.input('parallelize_qpoints', valid_type=orm.Bool, default=lambda: orm.Bool(False))
        # Optional budget that `run_atoms` hands to `distribute_base_workchains`.
        spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False)
        spec.input(
            'init_walltime', valid_type=int, default=3600, non_db=True,
            help='The walltime of the initialization `HpBaseWorkChain` in seconds (default: 3600).'
        )
        spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
            help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')

        spec.outline(
            cls.run_init,
            cls.inspect_init,
            while_(cls.should_run_atoms)(
                cls.run_atoms,
            ),
            cls.inspect_atoms,
            cls.run_final,
            cls.inspect_final,
            cls.results
        )

        spec.expose_outputs(HpBaseWorkChain)

        spec.exit_code(300, 'ERROR_ATOM_WORKCHAIN_FAILED',
            message='A child work chain failed.')
        spec.exit_code(301, 'ERROR_INITIALIZATION_WORKCHAIN_FAILED',
            message='The child work chain failed.')
        spec.exit_code(302, 'ERROR_FINAL_WORKCHAIN_FAILED',
            message='The child work chain failed.')

    def run_init(self):
        """Run an initialization `HpBaseWorkChain` that will determine which kinds need to be perturbed.

        By performing an `only_initialization` calculation only the symmetry analysis will be performed to
        determine which kinds are to be perturbed. This information is parsed and can be used to determine
        exactly how many `HpBaseWorkChains` have to be launched in parallel.

        The submitted node is stored in the context under the key `initialization`.
        """
        inputs = AttributeDict(self.exposed_inputs(HpBaseWorkChain))
        inputs.only_initialization = orm.Bool(True)
        inputs.clean_workdir = self.inputs.clean_workdir
        # The initialization run gets its own, user-configurable walltime.
        inputs.hp.metadata.options.max_wallclock_seconds = self.inputs.init_walltime
        inputs.metadata.call_link_label = 'initialization'

        node = self.submit(HpBaseWorkChain, **inputs)
        self.to_context(initialization=node)
        self.report(f'launched initialization HpBaseWorkChain<{node.pk}>')

    def inspect_init(self):
        """Inspect the initialization `HpBaseWorkChain`.

        On success, the `hubbard_sites` entry of its output parameters is stored in the
        context as a list of `(site_index, site_kind)` pairs that `run_atoms` consumes.

        :return: `ERROR_INITIALIZATION_WORKCHAIN_FAILED` if the child did not finish
            successfully, otherwise `None`.
        """
        workchain = self.ctx.initialization

        if not workchain.is_finished_ok:
            self.report(f'initialization work chain {workchain} failed with status {workchain.exit_status}, aborting.')
            return self.exit_codes.ERROR_INITIALIZATION_WORKCHAIN_FAILED

        output_params = workchain.outputs.parameters.get_dict()
        self.ctx.hubbard_sites = list(output_params['hubbard_sites'].items())

    def should_run_atoms(self):
        """Return whether there are more atoms to run.

        :return: `True` while `self.ctx.hubbard_sites` still contains unprocessed sites.
        """
        return bool(self.ctx.hubbard_sites)

    def run_atoms(self):
        """Run a separate `HpBaseWorkChain` for each of the defined Hubbard atoms.

        Each call submits one batch of children and removes the corresponding sites from
        `self.ctx.hubbard_sites`. Without `max_concurrent_base_workchains` the batch
        contains all remaining sites; otherwise its size is the length of the list
        returned by `distribute_base_workchains`. Every child is stored in the context
        under the key `atom_<site_index>`.
        """
        parallelize_qpoints = self.inputs.parallelize_qpoints.value
        workflow = HpParallelizeQpointsWorkChain if parallelize_qpoints else HpBaseWorkChain

        # `-1` marks "no limit to forward to the child" for a site.
        max_concurrent_base_workchains_sites = [-1] * len(self.ctx.hubbard_sites)
        if 'max_concurrent_base_workchains' in self.inputs:
            # NOTE(review): if this helper can return an empty list (e.g. for a budget
            # of zero), no site is popped below and the enclosing `while_` loop would
            # never terminate - confirm against `distribute_base_workchains`.
            max_concurrent_base_workchains_sites = distribute_base_workchains(
                len(self.ctx.hubbard_sites), self.inputs.max_concurrent_base_workchains.value
            )

        for max_concurrent_base_workchains_site in max_concurrent_base_workchains_sites:
            site_index, site_kind = self.ctx.hubbard_sites.pop(0)
            do_only_key = f'perturb_only_atom({site_index})'
            key = f'atom_{site_index}'

            # Fresh copy of the exposed inputs per site, since it is modified below.
            inputs = AttributeDict(self.exposed_inputs(HpBaseWorkChain))
            inputs.clean_workdir = self.inputs.clean_workdir

            # Restrict this child to the current site only.
            parameters = inputs.hp.parameters.get_dict()
            parameters['INPUTHP'][do_only_key] = True
            inputs.hp.parameters = orm.Dict(parameters)

            inputs.metadata.call_link_label = key

            # The per-site budget is only forwarded to the q-point parallelizing work chain.
            if parallelize_qpoints and max_concurrent_base_workchains_site != -1:
                inputs.max_concurrent_base_workchains = orm.Int(max_concurrent_base_workchains_site)

            node = self.submit(workflow, **inputs)
            self.to_context(**{key: node})
            name = workflow.__name__
            self.report(f'launched {name}<{node.pk}> for atomic site {site_index} of kind {site_kind}')

    def inspect_atoms(self):
        """Inspect each parallel atom `HpBaseWorkChain`.

        :return: `ERROR_ATOM_WORKCHAIN_FAILED` for the first child stored under an
            `atom_` context key that did not finish successfully, otherwise `None`.
        """
        for key, workchain in self.ctx.items():
            if not key.startswith('atom_'):
                continue

            if not workchain.is_finished_ok:
                self.report(f'child work chain {workchain} failed with status {workchain.exit_status}, aborting.')
                return self.exit_codes.ERROR_ATOM_WORKCHAIN_FAILED

    def run_final(self):
        """Perform the final `HpCalculation` to collect the various components of the chi matrices.

        The `retrieved` outputs of all per-atom children are passed as `hp.parent_hp`, keyed
        by their `atom_<site_index>` context key. The submitted node is stored in the context
        under the key `compute_hp`.
        """
        inputs = AttributeDict(self.exposed_inputs(HpBaseWorkChain))
        # `hp.parent_scf` is already part of the exposed inputs and is passed on unchanged.
        inputs.hp.parent_hp = {key: wc.outputs.retrieved for key, wc in self.ctx.items() if key.startswith('atom_')}
        inputs.hp.metadata.options.max_wallclock_seconds = 3600  # 1 hour is more than enough
        inputs.metadata.call_link_label = 'compute_hp'

        node = self.submit(HpBaseWorkChain, **inputs)
        self.to_context(compute_hp=node)
        self.report(f'launched HpBaseWorkChain<{node.pk}> to collect matrices')

    def inspect_final(self):
        """Inspect the final `HpBaseWorkChain`.

        :return: `ERROR_FINAL_WORKCHAIN_FAILED` if the child did not finish successfully,
            otherwise `None`.
        """
        workchain = self.ctx.compute_hp

        if not workchain.is_finished_ok:
            self.report(f'final work chain {workchain} failed with status {workchain.exit_status}, aborting.')
            return self.exit_codes.ERROR_FINAL_WORKCHAIN_FAILED

    def results(self):
        """Retrieve the results from the final matrix collection workchain."""
        self.out_many(self.exposed_outputs(self.ctx.compute_hp, HpBaseWorkChain))

    def on_terminated(self):
        """Clean the working directories of all child calculations if `clean_workdir=True` in the inputs."""
        super().on_terminated()

        if self.inputs.clean_workdir.value is False:
            self.report('remote folders will not be cleaned')
            return

        cleaned_calcs = []

        for called_descendant in self.node.called_descendants:
            if isinstance(called_descendant, orm.CalcJobNode):
                try:
                    called_descendant.outputs.remote_folder._clean()  # pylint: disable=protected-access
                    cleaned_calcs.append(called_descendant.pk)
                except (OSError, KeyError):
                    # Cleaning is best effort: a calculation whose remote folder cannot
                    # be cleaned is skipped and simply not listed in the report below.
                    pass

        if cleaned_calcs:
            self.report(f"cleaned remote folders of calculations: {' '.join(map(str, cleaned_calcs))}")