diff --git a/src/aiida_pseudo/cli/family.py b/src/aiida_pseudo/cli/family.py index 15ae363..9535467 100644 --- a/src/aiida_pseudo/cli/family.py +++ b/src/aiida_pseudo/cli/family.py @@ -85,8 +85,17 @@ def cmd_family_cutoffs_set(family, cutoffs, stringency, unit): # noqa: D301 except ValueError as exception: raise click.BadParameter(f'`{cutoffs.name}` contains invalid JSON: {exception}', param_hint='CUTOFFS') + cutoffs_dict = {} + for element, values in data.items(): + try: + cutoffs_dict[element] = {'cutoff_wfc': values['cutoff_wfc'], 'cutoff_rho': values['cutoff_rho']} + except KeyError as exception: + raise click.BadParameter( + f'`{cutoffs.name}` is missing cutoffs for element `{element}`: {exception}', param_hint='CUTOFFS' + ) from exception + try: - family.set_cutoffs(data, stringency, unit=unit) + family.set_cutoffs(cutoffs_dict, stringency, unit=unit) except ValueError as exception: raise click.BadParameter(f'`{cutoffs.name}` contains invalid cutoffs: {exception}', param_hint='CUTOFFS') diff --git a/tests/cli/test_family.py b/tests/cli/test_family.py index 230e3bb..98dcb61 100644 --- a/tests/cli/test_family.py +++ b/tests/cli/test_family.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # pylint: disable=unused-argument,redefined-outer-name """Tests for the command `aiida-pseudo family`.""" +from copy import deepcopy import json from aiida.orm import Group @@ -31,13 +32,15 @@ def test_family_cutoffs_set(run_cli_command, get_pseudo_family, generate_cutoffs assert "Error: Missing option '-s' / '--stringency'" in result.output assert sorted(family.get_cutoff_stringencies()) == sorted(['low', 'normal']) - # Invalid cutoffs structure - filepath.write_text(json.dumps({'Ar': {'cutoff_rho': 300}})) + # Missing cutoffs + high_cutoffs = deepcopy(cutoffs_dict['high']) + high_cutoffs['Ar'].pop('cutoff_wfc') + filepath.write_text(json.dumps(high_cutoffs)) result = run_cli_command(cmd_family_cutoffs_set, [family.label, str(filepath), '-s', 'high'], raises=True) - assert 'Error: Invalid value for CUTOFFS:' in result.output + assert 'Error: Invalid value for CUTOFFS: ' in result.output assert sorted(family.get_cutoff_stringencies()) == sorted(['low', 'normal']) - # Set correct stringency + # Set the high stringency stringency = 'high' filepath.write_text(json.dumps(cutoffs_dict['high'])) result = run_cli_command(cmd_family_cutoffs_set, [family.label, str(filepath), '-s', stringency]) @@ -46,6 +49,16 @@ def test_family_cutoffs_set(run_cli_command, get_pseudo_family, generate_cutoffs assert sorted(family.get_cutoff_stringencies()) == sorted(['low', 'normal', 'high']) assert family.get_cutoffs(stringency) == cutoffs_dict[stringency] + # Additional keys in the cutoffs should be accepted and simply ignored + stringency = 'invalid' + high_cutoffs = deepcopy(cutoffs_dict['high']) + high_cutoffs['Ar']['GME'] = 'moon' + filepath.write_text(json.dumps(high_cutoffs)) + result = run_cli_command(cmd_family_cutoffs_set, [family.label, str(filepath), '-s', stringency]) + assert 'Success: set cutoffs for' in result.output + assert sorted(family.get_cutoff_stringencies()) == sorted(['low', 'normal', 'high', stringency]) + assert family.get_cutoffs(stringency) == cutoffs_dict['high'] + @pytest.mark.usefixtures('clear_db') def test_family_cutoffs_set_unit(run_cli_command, get_pseudo_family, generate_cutoffs, tmp_path):