Print train/val/test sets with labels #91
base: develop_2023
Conversation
@PanayotisManganaris thanks for contributing to refactoring! I think it is necessary, especially because the dataloader helper functions are outgrowing themselves and have gotten pretty complex. So to summarize these changes to make sure I understand: the main goal of this PR is to log instance ids along with model outputs.
I'm wondering if we should just use this opportunity to overhaul the model API. Maybe we have to think about how to do it in a way that doesn't break backwards compatibility for people, though. What do you think of this alternate approach? Instead of returning the current format, each dataset item could be a single dictionary:

```python
{
    "id": "JVASP-0001",
    "graph": dgl.DGLGraph(...),
    "line_graph": dgl.DGLGraph(...),
    "target": 1.0,  # e.g. formation energy in eV/at
}
```

Actually, it may be cleaner to keep inputs and targets in separate dictionaries. In this case:

```python
x = {
    "graph": dgl.DGLGraph(...),
    "line_graph": dgl.DGLGraph(...),
}
y = {
    "id": "JVASP-0001",
    "energy_peratom": 1.0,  # e.g. formation energy in eV/at
    "forces": torch.ones(n_atoms, 3),  # forces in eV/Angstrom
}
```

This will require reworking the loss and evaluation code, but I think the granularity could be nice, especially if you want to treat different targets with different types of losses.

Alternatively, we can keep the inputs as a tuple. To simplify things, maybe we could rewrite the batch collators.

Thoughts? @knc6?
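To make the loss-rework idea concrete, here is a hypothetical sketch of how per-target criteria could consume a `y` dictionary like the one proposed above. The `multitarget_loss` helper, the criterion choices, and the weighting scheme are all illustrative, not existing ALIGNN code:

```python
import torch
import torch.nn.functional as F

# Hypothetical mapping (not ALIGNN API): a different criterion per target key.
CRITERIA = {
    "energy_peratom": F.mse_loss,  # scalar crystal-level target
    "forces": F.huber_loss,        # per-atom (n_atoms, 3) target
}


def multitarget_loss(pred: dict, y: dict, weights: dict = None) -> torch.Tensor:
    """Sum weighted per-target losses over keys present in both dicts."""
    weights = weights or {}
    total = torch.zeros(())
    for key, criterion in CRITERIA.items():
        if key in pred and key in y:
            total = total + weights.get(key, 1.0) * criterion(pred[key], y[key])
    return total
```

Non-tensor keys like `id` simply never appear in `CRITERIA`, so they pass through the batch untouched.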
Thanks Brian! Your summary is appropriate. I like the double-dictionary format, but tuple unpacking might be significantly faster than a dictionary lookup. I only worry because training could involve hundreds of lookups. If you agree, then I think function overloading could support both interfaces.
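The lookup-cost worry is easy to measure directly. A rough micro-benchmark (pure Python, timings machine-dependent, not a definitive benchmark) bounds the per-sample cost of each access pattern:

```python
import timeit

# Compare unpacking a 3-tuple against three dictionary lookups per sample.
sample_tuple = ("JVASP-0001", object(), 1.0)
sample_dict = {"id": "JVASP-0001", "graph": object(), "target": 1.0}

n = 1_000_000
t_tuple = timeit.timeit("i, g, t = s", globals={"s": sample_tuple}, number=n)
t_dict = timeit.timeit(
    "i, g, t = s['id'], s['graph'], s['target']",
    globals={"s": sample_dict},
    number=n,
)
print(f"per-sample: tuple unpack {t_tuple / n * 1e9:.0f} ns, "
      f"dict lookups {t_dict / n * 1e9:.0f} ns")
```

Either way the cost is nanoseconds per sample, which is likely dwarfed by graph construction and the forward pass.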
I agree
@knc6 so in the current version for atomwise outputs, the dataset stores the forces and the replicated per-atom formation energy as node features, right? So I think if we move those targets out of the graph attributes and pack them into a dictionary, it should be straightforward enough to batch them ourselves. In some prototyping code I had something like:

```python
from typing import Dict, List, Tuple

import dgl
import torch


def collate_line_graph(
    samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, Dict[str, torch.Tensor]]]
):
    """Dataloader helper to batch graphs across `samples`."""
    graphs, line_graphs, targets = map(list, zip(*samples))
    energy = torch.tensor([t["energy"] for t in targets])
    forces = torch.cat([t["forces"] for t in targets], dim=0)
    stresses = torch.stack([t["stresses"] for t in targets])
    targets = dict(total_energy=energy, forces=forces, stresses=stresses)
    return dgl.batch(graphs), dgl.batch(line_graphs), targets
```

This way we wouldn't have to replicate crystal-level outputs to pack them into node attributes.

I think we should also consider the usability of the data format returned by the model itself. One option is to return a dictionary there as well. I guess a drawback is that this changes the API, so maybe we would want to introduce an option that preserves the current behavior? I would be in favor of supporting the current behavior for single-target cases (e.g., band gap regression models), but I think I would prefer to require the dictionary interface for force-field-type models.
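The target-batching half of that collate function can be exercised without any graphs. A small sketch (shapes arbitrary, `batch_targets` is a hypothetical name) showing how the pieces combine:

```python
import torch

# Hypothetical sketch of batching per-sample target dicts directly, instead
# of replicating crystal-level values into node features: scalar energies
# stack into (batch,), per-atom forces concatenate along the atom axis, and
# per-crystal stresses stack into (batch, 3, 3).
def batch_targets(targets):
    return dict(
        total_energy=torch.tensor([t["energy"] for t in targets]),
        forces=torch.cat([t["forces"] for t in targets], dim=0),
        stresses=torch.stack([t["stresses"] for t in targets]),
    )


samples = [
    {"energy": 1.0, "forces": torch.zeros(4, 3), "stresses": torch.eye(3)},
    {"energy": 2.0, "forces": torch.zeros(6, 3), "stresses": torch.eye(3)},
]
batched = batch_targets(samples)
# total_energy: (2,); forces: (10, 3); stresses: (2, 3, 3)
```

Note that the concatenated `forces` tensor lines up with the node ordering of `dgl.batch`, so per-atom predictions and targets stay aligned.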
@PanayotisManganaris I think your point about performance is reasonable, but I think the usability is probably worth it. Maybe NamedTuples could be a solution if it ends up being a performance bottleneck, but I'm not sure how annoying it would be to dynamically generate those to make them easy to use with any combination of targets.
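If NamedTuples did become the fallback, dynamically generating them per target combination is fairly cheap. A sketch, where the `make_target_type` helper and its caching scheme are hypothetical:

```python
from collections import namedtuple

# Hypothetical helper: build (and cache) a namedtuple class for a given set
# of target keys, giving attribute access with tuple-like lookup speed for
# any combination of targets.
_target_types = {}


def make_target_type(keys):
    key = tuple(sorted(keys))
    if key not in _target_types:
        _target_types[key] = namedtuple("Targets", key)
    return _target_types[key]


Targets = make_target_type({"energy", "forces"})
y = Targets(energy=1.0, forces=None)
```

Because the classes are cached by key set, repeated calls for the same target combination return the identical type, so `isinstance` checks and batching code stay consistent across samples.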
I see Brian incorporated my friend Habibur's quick fix for the printout bug earlier this weekend.
I've gone further to change the alignn.graph.StructureDataset components and the batching functions to allow more flexibility when iterating through labeled data.
The large "complicated forward methods pass tests" commit adds some conditional logic to the model's forward methods to accommodate the labels passed through the Dataset batches. I'm interested in removing some of this complexity in the future, but the use of the ignite Trainer engine makes this branching the most straightforward way to pass the model tests.
Otherwise, conditionally changing the prepare_batch function could achieve the same result, but it would be much less readable.
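For comparison, that conditional prepare_batch alternative could look something like the following sketch. The branching and names are illustrative, not the actual ALIGNN or ignite code; dgl graph device handling is elided:

```python
import torch


def _to(x, device, non_blocking):
    # Move tensors to the target device; dgl graphs would need their own
    # .to(device) handling, elided in this sketch.
    if torch.is_tensor(x) and device is not None:
        return x.to(device, non_blocking=non_blocking)
    return x


def prepare_batch(batch, device=None, non_blocking=False):
    """Split a collated batch into inputs and targets, branching on whether
    the targets are a dictionary (force-field style) or a single tensor."""
    *inputs, targets = batch
    if isinstance(targets, dict):
        targets = {k: _to(v, device, non_blocking) for k, v in targets.items()}
    else:
        targets = _to(targets, device, non_blocking)
    return tuple(_to(x, device, non_blocking) for x in inputs), targets
```

The single-tensor branch preserves the current single-target behavior, while the dictionary branch serves the multi-target case, which is why it concentrates the format handling in one place instead of in every forward method.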
I hope to continue refactoring ALIGNN over the long term, with the aim of making adoption and adaptation easier.
Thank you.