
Print train/val/test sets with labels #91

Open
wants to merge 17 commits into develop_2023

Conversation

pmiam

@pmiam pmiam commented Jan 23, 2023

I see Brian incorporated my friend Habibur's quick fix for the printout bug earlier this weekend.

I've gone further to change the alignn.graphs.StructureDataset components and the batching functions to allow more flexibility when iterating through labeled data.

The large "complicated forward methods pass tests" commit adds some conditional logic to the model's forward methods to accommodate the labels passed through the Dataset batches. I'm interested in removing some of this complexity in the future, but the use of the ignite Trainer engine makes this branching the most straightforward way to pass the model tests.

Alternatively, conditionally changing the prepare_batch function could achieve the same thing, but it would be much less readable.

I hope to continue refactoring ALIGNN on a long-term basis, with the aim of making adoption and adaptation easier.
Thank you.

@bdecost
Collaborator

bdecost commented Jan 23, 2023

@PanayotisManganaris thanks for contributing to the refactoring! I think it's needed, especially because the dataloader helper functions are sort of outgrowing themselves and have gotten pretty complex.

So to summarize these changes to make sure I understand: the main goal of this PR is to log instance ids along with model outputs, so the changes cover:

  • add the instance id to the return value of StructureDataset.__getitem__
  • modify the batch collators accordingly
  • also modify model.forward, because we don't want the complexity of splitting up the batch data differently for the ignite training engine versus the evaluation callbacks that may need the instance ids (rough sketch after this list)
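Roughly, I'm picturing something like this for the first two points (a sketch only; the attribute names are assumptions, not the actual code in this PR):

import dgl
import torch
from torch.utils.data import Dataset


class StructureDataset(Dataset):
    """Sketch only: self.ids / self.graphs / self.line_graphs / self.labels
    are assumed attributes, not the actual fields in alignn.graphs."""

    def __getitem__(self, idx):
        # return the instance id alongside the graphs and the target
        return self.ids[idx], self.graphs[idx], self.line_graphs[idx], self.labels[idx]


def collate_line_graph(samples):
    """Batch graphs across `samples`, keeping the instance ids with the targets."""
    ids, graphs, line_graphs, labels = map(list, zip(*samples))
    return dgl.batch(graphs), dgl.batch(line_graphs), torch.tensor(labels), ids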

I'm wondering if we should just use this opportunity to overhaul the model API. Maybe we have to think about how to do it in a way that doesn't break backwards compatibility for people though.

What do you think of this alternate approach? Instead of returning tuples, the dataloader could return dicts containing the data, like

{
    "id": "JVASP-0001",
    "graph": dgl.DGLGraph(...),
    "line_graph": dgl.DGLGraph(...),
    "target": 1.0, # e.g. formation energy in eV/at
}

Actually, it may be cleaner to keep inputs and targets in separate dictionaries. In that case, "id" may go more naturally with the targets. Then for multi-target training (like energy + force training) we can have

x = {
    "graph": dgl.DGLGraph(...),
    "line_graph": dgl.DGLGraph(...),
}
y = {
    "id": "JVASP-0001",
    "energy_peratom": 1.0, # e.g. formation energy in eV/at
    "forces": torch.ones(n_atoms, 3), # forces in eV/Angstrom
}

This will require reworking the loss and evaluation code, but I think the granularity could be nice, especially if you want to treat different targets with different types of losses.
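For example, a loss keyed on the target dict could look roughly like this (just a sketch; the key names and weights are placeholders, not anything in the current codebase):

import torch
import torch.nn.functional as F


def multitarget_loss(predictions, targets, weights):
    """Sketch: weighted sum of per-key losses over whichever keys are in the batch."""
    loss = torch.tensor(0.0)
    for key, weight in weights.items():
        if key in targets:
            loss = loss + weight * F.mse_loss(predictions[key], targets[key])
    return loss


# e.g. loss = multitarget_loss(pred, y, {"energy_peratom": 1.0, "forces": 0.1})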

Alternatively, we can keep the inputs as a tuple of DGLGraphs. This might be kind of nice for writing code that works equally well with ALIGNN and CGCNN or similar. And we might be able to do it without breaking changes?

To simplify things, maybe we could rewrite the batch collators and forward methods using singledispatch.
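Something like this is what I mean (a rough sketch; unpack_batch and the dict keys are just illustrative names, not the current alignn API):

from functools import singledispatch


@singledispatch
def unpack_batch(batch):
    """Sketch: split a batch into (inputs, targets), dispatching on the batch type."""
    raise TypeError(f"unsupported batch type: {type(batch)!r}")


@unpack_batch.register
def _(batch: tuple):
    # legacy layout: (graph, line_graph, target)
    graph, line_graph, target = batch
    return (graph, line_graph), target


@unpack_batch.register
def _(batch: dict):
    # dict layout: {"graph": ..., "line_graph": ..., "target": ...}
    return (batch["graph"], batch["line_graph"]), batch["target"]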

Thoughts? @knc6 ?

@pmiam
Author

pmiam commented Jan 23, 2023

Thanks Brian! Your summary is accurate.

I like the double-dict approach. Currently, my batch functions return a 2-tuple of xtpl and y. This is particularly awkward because xtpl can vary in size depending on the number of alignn layers requested in the config and/or the choice of model. Dictionaries would make the forward methods much simpler.

Tuple unpacking might be significantly faster than a dictionary lookup; I only worry because training could involve hundreds of lookups. If you agree, then I think using function overloading with singledispatch is very clever, and it could probably be kept out of the hair of model developers.

@knc6
Collaborator

knc6 commented Jan 23, 2023

I agree a dict is a much cleaner approach. We are actually using a dict structure while working with the atomwise model: https://github.com/usnistgov/alignn/blob/main/alignn/train.py#L368 . It would be nice to have a similar format in the StructureDataset itself. Could it create issues with dgl batching? We'll have to check that carefully.

@bdecost
Collaborator

bdecost commented Jan 23, 2023

@knc6 so in the current version for atomwise outputs, the dataset stores the forces and the replicated per-atom formation energy as node features, right? So dgl.batch handles stacking the features for us: https://github.com/usnistgov/alignn/blob/main/alignn/graphs.py#L618

I think if we move those targets out of the graph attributes and pack them into a dictionary it should be straightforward enough to batch them ourselves.

In some prototyping code I had something like this:

from typing import Dict, List, Tuple

import dgl
import torch


def collate_line_graph(
    samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, Dict[str, torch.Tensor]]]
):
    """Dataloader helper to batch graphs across `samples`."""
    graphs, line_graphs, targets = map(list, zip(*samples))

    # batch graph-level and node-level targets separately
    energy = torch.tensor([t["energy"] for t in targets])
    forces = torch.cat([t["forces"] for t in targets], dim=0)
    stresses = torch.stack([t["stresses"] for t in targets])
    targets = dict(total_energy=energy, forces=forces, stresses=stresses)

    return dgl.batch(graphs), dgl.batch(line_graphs), targets

This way we wouldn't have to replicate crystal-level outputs to pack them into node attributes.

I think we should also consider the usability of the data format returned by the model itself.

One option is to return a Dict[str,Tensor] with the same keys as what we're proposing for the dataset. This has the advantage that we don't need to treat graph-level values like total energies or stress tensors specially, and we can write custom loss functions that use those keys to do whatever is appropriate.
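Something along these lines (a toy sketch; the head names, keys, and hidden size are made up and not the actual ALIGNN model):

import torch


class ToyHeads(torch.nn.Module):
    """Sketch: output heads returning a dict keyed the same way as the dataset targets."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.energy_head = torch.nn.Linear(hidden, 1)
        self.force_head = torch.nn.Linear(hidden, 3)

    def forward(self, node_features: torch.Tensor) -> dict:
        return {
            "energy_peratom": self.energy_head(node_features).squeeze(-1),
            "forces": self.force_head(node_features),
        }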

I guess a drawback is that this changes the API, so maybe we would want to introduce an option that preserves the current behavior? I would be in favor of supporting the current behavior for single-target cases (e.g., band gap regression models), but I think I would prefer to require the dictionary interface for force-field-type models.

@bdecost
Collaborator

bdecost commented Jan 23, 2023

@PanayotisManganaris I think your point about performance is reasonable, but the usability is probably worth it. NamedTuples could be a solution if it ends up being a performance bottleneck, though I'm not sure how annoying it would be to dynamically generate them to make it easy to use with any combination of targets.
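For what it's worth, generating one from a target configuration could look roughly like this (a sketch; the target names are placeholders):

from collections import namedtuple

import torch

# Sketch: build the NamedTuple type once from the configured target names,
# so batches keep cheap attribute access. The target names are placeholders.
Targets = namedtuple("Targets", ["energy_peratom", "forces"])

batch_targets = Targets(
    energy_peratom=torch.tensor([1.0, 2.0]),
    forces=torch.ones(8, 3),
)
value = batch_targets.energy_peratom  # attribute access instead of a dict lookup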
