Skip to content

Commit

Permalink
👌 Simplify Wannier90WorkChain.get_builder_from_protocol()
Browse files Browse the repository at this point in the history
The `get_builder_from_protocol()` method of the `Wannier90WorkChain` relied on
two external functions `get_scf_builder()` and `get_nscf_builder()` to construct
the builder of the corresponding subprocesses. This had a few issues:

* It was more challenging to figure out where parameters are set, especially
  since the `get_nscf_builder()` method was again nested in the
  `get_scf_builder()` one.
* The overrides would not be properly respected.

Here we simplify how the `get_builder_from_protocol()` method in two ways:

* The `SpinType`-related inputs are stored in the same protocol overrides YAML
  which is used for the inputs related to the input arguments of the method.
  In case `SpinTyp` is equal to `NON_COLLINEAR` or `SPIN_ORBIT`, the overrides
  are adapted accordingly, but the user-specified `overrides` still take
  precedence. In order to avoid issues with calling the
  `PwBaseWorkChain.get_builder_from_protocol()` method, `SpinType.NONE` is
  passed to the method for the SCF and NSCF steps.
* Instead of using the external functions, pass the `pseudo_family` and `nbnd`
  through the overrides, and use the `get_builder_from_protocol()` method.
  Since we can't yet remove/pop inputs through the overrides, adapt the
  `kpoints` input for the NSCF after obtaining the populated builder.

Finally, note that the `protocol` was not passed to SCF and NSCF builder
builder generators. This has now been fixed.
  • Loading branch information
mbercx authored and qiaojunfeng committed Nov 23, 2023
1 parent f28e7df commit a143916
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,27 @@ retrieve_matrices:
- '*.amn'
- '*.mmn'
- '*.spn'
spin_noncollinear:
scf:
pw:
parameters:
SYSTEM:
noncolin: True
nscf:
pw:
parameters:
SYSTEM:
noncolin: True
spin_orbit:
scf:
pw:
parameters:
SYSTEM:
lspinorb: True
noncolin: True
nscf:
pw:
parameters:
SYSTEM:
lspinorb: True
noncolin: True
15 changes: 15 additions & 0 deletions src/aiida_wannier90_workflows/workflows/protocols/wannier90.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
default_inputs:
clean_workdir: False
scf:
kpoints_distance: 0.2
nscf:
kpoints_distance: 0.2
pw:
parameters:
SYSTEM:
nosym: True
noinv: True
CONTROL:
calculation: nscf
restart_mode: from_scratch
ELECTRONS:
diago_full_acc: True
startingpot: file
default_protocol: moderate
protocols:
moderate:
Expand Down
87 changes: 52 additions & 35 deletions src/aiida_wannier90_workflows/workflows/wannier90.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,6 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
get_pseudo_orbitals,
get_semicore_list,
)
from aiida_wannier90_workflows.utils.workflows.builder.generator import (
get_nscf_builder,
get_scf_builder,
)
from aiida_wannier90_workflows.utils.workflows.builder.projections import (
guess_wannier_projection_types,
)
Expand Down Expand Up @@ -359,33 +355,52 @@ def get_builder_from_protocol( # pylint: disable=unused-argument

# Adapt overrides based on input arguments
# Note: if overrides are specified, they take precedence!
argument_overrides = cls.get_protocol_overrides()
protocol_overrides = cls.get_protocol_overrides()

if plot_wannier_functions:
overrides = recursive_merge(
argument_overrides["plot_wannier_functions"], overrides
protocol_overrides["plot_wannier_functions"], overrides
)

if retrieve_hamiltonian:
overrides = recursive_merge(
argument_overrides["retrieve_hamiltonian"], overrides
protocol_overrides["retrieve_hamiltonian"], overrides
)

if retrieve_matrices:
overrides = recursive_merge(
argument_overrides["retrieve_matrices"], overrides
protocol_overrides["retrieve_matrices"], overrides
)

inputs = cls.get_protocol_inputs(protocol=protocol, overrides=overrides)

if pseudo_family is None:
if spin_type == SpinType.SPIN_ORBIT:
# I use pseudo-dojo for SOC
# Use fully relativistic PseudoDojo for SOC
pseudo_family = "PseudoDojo/0.4/PBE/FR/standard/upf"
else:
pseudo_family = Wannier90BaseWorkChain.get_protocol_inputs(
protocol=protocol
)["meta_parameters"]["pseudo_family"]
# Use the one used in Wannier90BaseWorkChain
pseudo_family = (
pseudo_family
or Wannier90BaseWorkChain.get_protocol_inputs(protocol=protocol)[
"meta_parameters"
]["pseudo_family"]
)

# As PwBaseWorkChain.get_builder_from_protocol() does not support SOC, we have to pass the
# desired parameters through the overrides. In this case we need to set the `pw.x`
# spin_type to SpinType.NONE, otherwise the builder will raise an error.
# This block should be removed once SOC is supported in PwBaseWorkChain.
if spin_type == SpinType.NON_COLLINEAR:
overrides = recursive_merge(
protocol_overrides["spin_noncollinear"], overrides
)
pw_spin_type = SpinType.NONE
elif spin_type == SpinType.SPIN_ORBIT:
overrides = recursive_merge(protocol_overrides["spin_orbit"], overrides)
pw_spin_type = SpinType.NONE
else:
pw_spin_type = spin_type

inputs = cls.get_protocol_inputs(protocol=protocol, overrides=overrides)

builder = cls.get_builder()
builder.structure = structure
Expand Down Expand Up @@ -414,48 +429,50 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
wannier_builder.pop("clean_workdir", None)
builder.wannier90 = wannier_builder._inputs(prune=True)

kpoints_distance = Wannier90BaseWorkChain.get_protocol_inputs(
protocol=protocol, overrides=wannier_overrides
)["meta_parameters"]["kpoints_distance"]

# Prepare SCF builder
scf_overrides = inputs.get("scf", {})
scf_builder = get_scf_builder(
scf_overrides["pseudo_family"] = pseudo_family
scf_builder = PwBaseWorkChain.get_builder_from_protocol(
code=codes["pw"],
structure=structure,
kpoints_distance=kpoints_distance,
pseudo_family=pseudo_family,
electronic_type=electronic_type,
spin_type=spin_type,
protocol=protocol,
overrides=scf_overrides,
electronic_type=electronic_type,
spin_type=pw_spin_type,
)
# Remove workchain excluded inputs
scf_builder["pw"].pop("structure", None)
scf_builder.pop("clean_workdir", None)
builder.scf = scf_builder._inputs(prune=True)

# Prepare NSCF builder
nscf_overrides = inputs.get("nscf", {})
nscf_overrides["pseudo_family"] = pseudo_family

num_bands = wannier_builder["wannier90"]["parameters"]["num_bands"]
exclude_bands = (
wannier_builder["wannier90"]["parameters"]
.get_dict()
.get("exclude_bands", [])
)
nbnd = num_bands + len(exclude_bands)
# Use explicit list of kpoints generated by wannier builder.
# Since the QE auto generated kpoints might be different from wannier90, here we explicitly
# generate a list of kpoint coordinates to avoid discrepencies.
kpoints = wannier_builder["wannier90"]["kpoints"]
nscf_overrides = inputs.get("nscf", {})
nscf_builder = get_nscf_builder(
nscf_overrides["pw"]["parameters"]["SYSTEM"]["nbnd"] = num_bands + len(
exclude_bands
)

nscf_builder = PwBaseWorkChain.get_builder_from_protocol(
code=codes["pw"],
structure=structure,
nbnd=nbnd,
kpoints=kpoints,
pseudo_family=pseudo_family,
electronic_type=electronic_type,
spin_type=spin_type,
protocol=protocol,
overrides=nscf_overrides,
electronic_type=electronic_type,
spin_type=pw_spin_type,
)
# Use explicit list of kpoints generated by wannier builder.
# Since the QE auto generated kpoints might be different from wannier90, here we explicitly
# generate a list of kpoint coordinates to avoid discrepancies.
nscf_builder.pop("kpoints_distance", None)
nscf_builder.kpoints = wannier_builder["wannier90"]["kpoints"]

# Remove workchain excluded inputs
nscf_builder["pw"].pop("structure", None)
nscf_builder.pop("clean_workdir", None)
Expand Down
1 change: 1 addition & 0 deletions tests/workflows/protocols/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
}
},
{"wannier90": {"auto_energy_windows_threshold": 0.01}},
{"nscf": {"pw": {"parameters": {"SYSTEM": {"calculation": "bands"}}}}},
),
)
def test_overrides(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
kpoints: 1331 kpts
kpoints_force_parity: false
pw:
code: test.quantumespresso.pw@localhost
metadata:
options:
max_wallclock_seconds: 43200
resources:
num_machines: 1
withmpi: true
parameters:
CONTROL:
calculation: nscf
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
restart_mode: from_scratch
tprnfor: true
tstress: true
ELECTRONS:
conv_thr: 4.0e-10
diago_full_acc: true
electron_maxstep: 80
mixing_beta: 0.4
startingpot: file
SYSTEM:
calculation: bands
degauss: 0.01
ecutrho: 240.0
ecutwfc: 30.0
nbnd: 16
noinv: true
nosym: true
occupations: smearing
smearing: cold
pseudos:
Si: Si.pbe-n-rrkjus_psl.1.0.0.UPF
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
kpoints_distance: 0.2
kpoints_force_parity: false
pw:
code: test.quantumespresso.pw@localhost
metadata:
options:
max_wallclock_seconds: 43200
resources:
num_machines: 1
withmpi: true
parameters:
CONTROL:
calculation: nscf
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
restart_mode: from_scratch
tprnfor: true
tstress: true
ELECTRONS:
conv_thr: 4.0e-10
diago_full_acc: true
electron_maxstep: 80
mixing_beta: 0.4
startingpot: file
SYSTEM:
calculation: bands
degauss: 0.01
ecutrho: 240.0
ecutwfc: 30.0
nbnd: 16
occupations: smearing
smearing: cold
pseudos:
Si: Si.pbe-n-rrkjus_psl.1.0.0.UPF
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
kpoints: 1331 kpts
kpoints_force_parity: false
pw:
code: test.quantumespresso.pw@localhost
metadata:
options:
max_wallclock_seconds: 43200
resources:
num_machines: 1
withmpi: true
parameters:
CONTROL:
calculation: nscf
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
restart_mode: from_scratch
tprnfor: true
tstress: true
ELECTRONS:
conv_thr: 4.0e-10
diago_full_acc: true
electron_maxstep: 80
mixing_beta: 0.4
startingpot: file
SYSTEM:
calculation: bands
degauss: 0.01
ecutrho: 240.0
ecutwfc: 30.0
nbnd: 16
noinv: true
nosym: true
occupations: smearing
smearing: cold
pseudos:
Si: Si.pbe-n-rrkjus_psl.1.0.0.UPF
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
kpoints: 1331 kpts
kpoints_force_parity: false
pw:
code: test.quantumespresso.pw@localhost
metadata:
options:
max_wallclock_seconds: 43200
resources:
num_machines: 1
withmpi: true
parameters:
CONTROL:
calculation: nscf
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
restart_mode: from_scratch
tprnfor: true
tstress: true
ELECTRONS:
conv_thr: 4.0e-10
diago_full_acc: true
electron_maxstep: 80
mixing_beta: 0.4
startingpot: file
SYSTEM:
calculation: bands
degauss: 0.01
ecutrho: 240.0
ecutwfc: 30.0
nbnd: 16
noinv: true
nosym: true
occupations: smearing
smearing: cold
pseudos:
Si: Si.pbe-n-rrkjus_psl.1.0.0.UPF

0 comments on commit a143916

Please sign in to comment.