Skip to content

Commit

Permalink
Check ACE return code and return standard output (#97)
Browse files Browse the repository at this point in the history
Modify the condition to trigger RuntimeError in the ACE training and print ACE output
  • Loading branch information
juraskov authored Apr 24, 2024
1 parent 3a5e750 commit 1f8b0ea
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions mlptrain/potentials/ace/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,34 @@ def _train(self) -> None:
p = Popen(
[shutil.which('julia'), f'{self.name}.jl'],
shell=False,
encoding='utf-8',
stdout=PIPE,
stderr=PIPE,
)
out, err = p.communicate(timeout=None)

filename_ace_out = 'ACE_output.out'

with open(filename_ace_out, 'a') as f:
f.write(f'ACE training output:\n{out}')
if err:
f.write(f'ACE training error:\n{err}')

delta_time = time() - start_time
logger.info(f'ACE training ran in {delta_time / 60:.1f} m')

if any(
(
delta_time < 0.01,
b'SYSTEM ABORT' in err,
'SYSTEM ABORT' in err,
p.returncode != 0,
not os.path.exists(f'{self.name}.json'),
)
):
raise RuntimeError(f'ACE train errored with:\n{err.decode()}\n')
raise RuntimeError(
f'ACE train errored with a return code:\n{p.returncode}\n'
f'and error:\n{err}\n'
)

for filename in (f'{self.name}_data.xyz', f'{self.name}.jl'):
os.remove(filename)
Expand Down Expand Up @@ -246,7 +258,7 @@ def _print_input(self, filename: str, **kwargs) -> None:
' asmerrs=true, weights=weights)\n'
'save_dict(save_name,'
' Dict("IP" => write_dict(IP), "info" => lsqinfo))\n'
'rmse_table(lsqinfo["errors"])\n'
'#rmse_table(lsqinfo["errors"])\n'
'println("The L2 norm of the fit is ", round(norm(lsqinfo["c"]), digits=2))\n',
file=inp_file,
)
Expand Down

0 comments on commit 1f8b0ea

Please sign in to comment.