Skip to content

Commit

Permalink
Contact Experiment Completed, Early Stopping, Evaluation Code for Cla…
Browse files Browse the repository at this point in the history
…ssification
  • Loading branch information
DanielChaseButterfield authored Sep 9, 2024
2 parents 16c73cf + e6f0443 commit 53b1046
Show file tree
Hide file tree
Showing 10 changed files with 253 additions and 141 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,5 @@ jobs:
uses: snickerbockers/submodules-init@v4
- name: Test with unittest
run: |
export MIHGNN_UNITTEST_SKIP_GITHUB_ACTION_CRASHERS=True
python -m unittest discover tests/ -v
29 changes: 9 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# MI-HGNN for contact estimation/classification
This repository implements a Morphologically-inspired Heterogeneous graph neural network for estimating contact information on the feet of a quadruped robot.
This repository implements a Morphology-Inspired Heterogeneous Graph Neural Network (MI-HGNN) for estimating contact information on the feet of a quadruped robot.

## Installation
To get started, setup a Conda Python environment with Python=3.11:
Expand All @@ -13,7 +13,7 @@ Then, install the library (and dependencies) with the following command:
pip install .
```

Note, if you have any issues with setup, refer to the `environment_files/README.md` so you can install the exact libraries we used.
Note, if you have any issues with setup, refer to `environment_files/README.md` so you can install the exact libraries we used.

## URDF Download
The necessary URDF files are part of git submodules in this repository, so run the following commands to download them:
Expand All @@ -31,46 +31,35 @@ in the `train_model` function.
The model weights will be saved in the following folder, based on the model
type and the randomly chosen model name (which is output in the terminal when training begins):
```
<repository base directory>/models/<model-type>-<model_name>/
<repository base directory>/models/<model_name>/
```
There will be the six models saved, one with the final model weights, and five with the best validation losses during training.

### LinTzuYaun Contact Dataset Model
### Contact Detection (Classification) Experiment

To train a model from the dataset from [Legged Robot State Estimation using Invariant Kalman Filtering and Learned Contact Events](https://arxiv.org/abs/2106.15713), run the following command within your Conda environment:
To train a model from the dataset from [MorphoSymm-Replication]([https://arxiv.org/abs/2106.15713](https://github.com/lunarlab-gatech/MorphoSymm-Replication/releases/tag/RepBugFixes)), run the following command within your Conda environment. Feel free to edit the model parameters within the file itself:

```
python research/train_classification.py
```

If you want to customize the model used, the number of layers, or the hidden size, feel free to change the corresponding variables.

To evaluate this model, edit `evaluator_classification.py` to specify which model to evaluate, its type, and the number of dataset entries to consider. Then, run the following command:
To evaluate a model, edit `evaluator_classification.py` to specify which model to evaluate. Then, run the following command:

```
python research/evaluator_classification.py
```

The visualization of the predicted and GT values will be found in a file called `model_eval_results.pdf` in the same directory as your model weights.

### QuadSDK & Real World GRF Model
### GRF (Regression) Model

Not Yet Implemented.

### Your Own Custom Model

Tutorial not yet written.


## Changing the model type
Currently, two model types are supported:
- `mlp`
- `heterogeneous_gnn`
To change the model type, please change the `model_type` parameter in the `train.py` and `evaluator.py` files.

## Editing this repository
If you want to make changes to the source files, feel free to edit them in the `src/grfgnn` folder, and then
If you want to make changes to the source files, feel free to edit them in the `src/mi_hgnn/' folder, and then
rebuild the library following the instructions in [#Installation](#installation).

## Paper Replication
To replicate our paper results with the model weights we trained, see `paper/README.md`.
To replicate our paper results with the model weights we trained, see `paper/README.md`.
58 changes: 31 additions & 27 deletions paper/README.md
Original file line number Diff line number Diff line change
@@ -1,41 +1,45 @@
# Paper Replication

This directory provides the model weights for all of our MI-HGNN models (and MLP models) referenced in the paper. Whenever a specific trained model is referenced in this README (for example, `ancient-salad-5`), it will be highlighted as shown, and there will be a folder on [Google Drive](https://drive.google.com/drive/folders/1NS5H_JIXW-pORyQUR15-t3mzf2byG26v?usp=sharing) with its name. In that folder will be its weights after the full 30 epochs of training, which were used to generate the paper results.
This directory provides the model weights for all of our MI-HGNN models referenced in the paper. Whenever a specific trained model is referenced in this README (for example, `ancient-salad-5`), it will be highlighted as shown, and there will be a folder on Georgia Tech's [Dropbox](https://www.dropbox.com/scl/fo/8p165xcfbdfwlcr3jx7tb/ABoxs5BOEXsQnJgFXF_Mjcc?rlkey=znrs7oyu29qsswpd3a5r55zk8&st=53v30ys3&dl=0) with its name. Unless otherwise specified, the model weights used for the paper were those trained the longest (have highest `epoch=` number in their .ckpt file).

To find the name of a specific model referenced in the paper or to replicate the results, refer to the following sections below which correspond to paper sections.

## Contact Detection (Classification) Experiment

To replicate the results of these experiments on your own end, input the checkpoint path into the `evaluator_classification.py` file found in the `research` directory of this repository.
Our models trained during this experiment can be found in the table below. For more details, see the `contact_experiment.csv` file in this directory. To evaluate the model metrics on your own end, input the checkpoint path into the `evaluator_classification.py` file found in the `research` directory of this repository:

As of the time of this commit, the main experiment has not been completed yet.
| Number of Layers | Hidden Sizes | Seed | State Accuracy (Test) | Model Name |
| ---------------- | ------------ | ---- |---------------------- | --------------------- |
| 8 | 128 | 0 | 0.874120593070984 | `gentle-morning-4` |
| 8 | 128 | 1 | 0.895811080932617 | `leafy-totem-5` |
| 8 | 128 | 2 | 0.868574500083923 | `different-oath-6` |
| 8 | 128 | 3 | 0.878039181232452 | `hopeful-mountain-7` |
| 8 | 128 | 4 | 0.855807065963745 | `revived-durian-8` |
| 8 | 128 | 5 | 0.875732064247131 | `robust-planet-9` |
| 8 | 128 | 6 | 0.883218884468079 | `super-microwave-10` |
| 8 | 128 | 7 | 0.880922436714172 | `valiant-dawn-11` |

The baseline models we compared to (ECNN, CNN-aug, CNN) were trained on this release: [MorphoSymm-Replication -> With Bug Fixes](https://github.com/lunarlab-gatech/MorphoSymm-Replication/releases/tag/RepBugFixes). See that repository for information on accessing those model weights, and replicating the Contact Detection Experiment Figure seen in our paper.

### Abalation Study

For this paper, we conducted an abalation study to see how parameter-efficient our model is. In the paper, we give the each trained model's layer number, hidden size, parameter number, and finally, the state accuracy on the test set. The table below relates these parameters to specific trained model names so that you can find the exact checkpoint weights for each model.

| Number of Layers | Hidden Sizes | Number of Parameters | State Accuracy (Test) | Model Name |
| ---------------- | ------------ | ---------------------| --------------------- | --------------------- |
| 4 | 5 | | | |
| 4 | 10 | | | |
| 4 | 25 | | | |
| 4 | 50 | | | |
| 4 | 128 | | | |
| 8 | 50 | | | |
| 8 | 128 | | | |
| 8 | 256 | | | |
| 12 | 50 | | | |
| 12 | 128 | | | |
| 12 | 256 | | | |

### Side note on normalization

In our paper, we mention that we found that our MI-HGNN model performed better without entry-wise normalization. We found this by running the two models below. This wasn't an exhausive experiment, which is why it only deserved a short reference in the paper. However, you can see that for this specific configuration of layer size and hidden size, our MI-HGNN model has a _______ increase in accuracy when disabling the entry-wise normalization used in this experiment.

| Number of Layers | Hidden Sizes | Normalization | State Accuracy (Test) | Model Name |
| ---------------- | ------------ | --------------| --------------------- | ----------------- |
| 12 | 128 | False | | |
| 12 | 128 | True | | |
We conducted an abalation study to see how parameter-efficient our model is. In the paper, we give the each trained model's layer number, hidden size, parameter number, and finally, the state accuracy on the test set. Here those values are associated with the model's name in the table below. For more details, see the `contact_experiment_ablation.csv` file in this directory.

| Number of Layers | Hidden Sizes | Model Name |
| ---------------- | ------------ | ---------------------- |
| 4 | 5 | `prime-water-16` |
| 4 | 10 | `driven-shape-17` |
| 4 | 25 | `autumn-terrain-18` |
| 4 | 50 | `comfy-dawn-19` |
| 4 | 128 | `prime-butterfly-20` |
| 4 | 256 | `youthful-galaxy-21` |
| 8 | 50 | `exalted-mountain-22` |
| 8 | 128 | `serene-armadillo-23` |
| 8 | 256 | `playful-durian-12` |
| 12 | 50 | `twilight-armadillo-15`|
| 12 | 128 | `sparkling-music-14` |
| 12 | 256 | `stoic-mountain-13` |


## Ground Reaction Force Estimation (Regression) Experiment

Expand Down
9 changes: 9 additions & 0 deletions paper/contact_experiment.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"Name","State","Notes","User","Tags","Created","GPU Type","Runtime","Sweep","ID","Group","Job Type","Updated","End Time","Hostname","Description","Commit","GitHub","GPU Count","batch_size","hidden_channels","lr","normalize","num_layers","num_parameters","optimizer","regression","seed","epoch","test_Accuracy","test_F1_Score_Leg_0","test_F1_Score_Leg_1","test_F1_Score_Leg_2","test_F1_Score_Leg_3","test_F1_Score_Leg_Avg","test_CE_loss","train_Accuracy","train_CE_loss","train_F1_Score_Leg_0","train_F1_Score_Leg_1","train_F1_Score_Leg_2","train_F1_Score_Leg_3","train_F1_Score_Leg_Avg","trainer/global_step","val_Accuracy","val_CE_loss","val_F1_Score_Leg_0","val_F1_Score_Leg_1","val_F1_Score_Leg_2","val_F1_Score_Leg_3","val_F1_Score_Leg_Avg"
"valiant-dawn-11","finished","-","","","2024-09-01T02:45:31.000Z","NVIDIA GeForce RTX 4090","7921","","qc5v7hq2","","","2024-09-01T04:57:32.000Z","2024-09-01T04:57:32.000Z","aellnar08106d","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/state-estimation-gnn/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","1","30","128","0.0001","true","8","1585282","adam","false","7","11","0.8809224367141724","0.9493856530476056","0.935029145210738","0.9372308446878636","0.9310457129322596","0.9381728389696168","0.24830123724984204","0.9666666388511658","0.032949131727218625","1","1","0.9090909090909092","1","0.9772727272727272","197769","0.892520546913147","0.34584198461352583","0.9549937648792652","0.9352832408808204","0.9484526033628464","0.9479270099132387","0.9466641547590428"
"super-microwave-10","finished","-","","","2024-09-01T02:42:58.000Z","NVIDIA GeForce RTX 4090 Laptop GPU","12255","","m3h353sj","","","2024-09-01T06:07:13.000Z","2024-09-01T06:07:13.000Z","aellnar08116l","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/state-estimation-gnn/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","1","30","128","0.0001","true","8","1585282","adam","false","6","11","0.8832188844680786","0.9426359403779891","0.9379542360160368","0.9349237797511736","0.9363857798233686","0.937974933992142","0.22593691841093608","1","0.0007627168670296669","1","1","1","1","1","197769","0.8945062756538391","0.25110076049185237","0.9582996866823864","0.9367402297310874","0.9489920041520022","0.945436877479958","0.9473671995113584"
"robust-planet-9","finished","-","","","2024-09-01T02:40:46.000Z","Tesla V100-PCIE-32GB","57218","","x03al74t","","","2024-09-01T18:34:24.000Z","2024-09-01T18:34:24.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","5","11","0.8757320642471313","0.9368620352139556","0.9395511729670596","0.9243936487081676","0.92408517892659","0.9312230089539432","0.3994351616966936","1","0.014091232419013977","1","1","1","1","1","197769","0.8953362703323364","0.32372366120207813","0.9586080645422588","0.9393461331901182","0.9487649415313752","0.9438173089399248","0.9476341120509192"
"revived-durian-8","finished","-","","","2024-09-01T02:40:04.000Z","Tesla V100-PCIE-32GB","57065","","0gci3y6k","","","2024-09-01T18:31:09.000Z","2024-09-01T18:31:09.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","4","11","0.8558070659637451","0.9230507735568092","0.9265147810375658","0.9244070560332048","0.9251778831900384","0.9247876234544046","0.3132841911315715","0.9666666388511658","0.020360757907231648","1","1","1","0.9473684210526316","0.986842105263158","197769","0.894138514995575","0.30622858110468476","0.9565076484314232","0.9368660105980318","0.9490269971735812","0.9439106725525838","0.9465778321889048"
"hopeful-mountain-7","finished","-","","","2024-09-01T02:39:42.000Z","Tesla V100-PCIE-32GB","57201","","uwkkryce","","","2024-09-01T18:33:03.000Z","2024-09-01T18:33:03.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","3","11","0.8780391812324524","0.9483838560653916","0.9340515565656868","0.9356194293234432","0.93041608228144","0.9371177310589904","0.29722806023702947","1","0.005201342205206553","1","1","1","1","1","197769","0.8960296511650085","0.32580944874108275","0.9559902594784628","0.9396684786255955","0.947300166966558","0.948313878080415","0.9478181957877578"
"different-oath-6","finished","-","","","2024-09-01T02:38:59.000Z","Tesla V100-PCIE-32GB","57209","","uog66inh","","","2024-09-01T18:32:28.000Z","2024-09-01T18:32:28.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","2","11","0.8685745000839233","0.9349225572056856","0.9279084972232032","0.9274321664688614","0.9280777747127807","0.9295852489026328","0.30687451548072275","1","0.00036975229158997536","1","1","1","1","1","197769","0.8981624245643616","0.2890206184466963","0.955764530696146","0.9395106365088708","0.951623159576723","0.9454474543770416","0.9480864452896952"
"leafy-totem-5","finished","-","","","2024-09-01T02:37:50.000Z","Tesla V100-PCIE-32GB","57376","","hwfmo10l","","","2024-09-01T18:34:06.000Z","2024-09-01T18:34:06.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","1","11","0.8958110809326172","0.948565573004125","0.9538250159477892","0.9479477870050814","0.946401498925348","0.949184968720586","0.2155239851976042","1","0.0006038505584001541","1","1","1","1","1","197769","0.8936552405357361","0.3025799200995997","0.957629849207372","0.9373885509540872","0.9510802245708864","0.9446412469406156","0.9476849679182404"
"gentle-morning-4","finished","-","","","2024-09-01T02:36:40.000Z","Tesla V100-PCIE-32GB","57082","","aegbo1u4","","","2024-09-01T18:28:02.000Z","2024-09-01T18:28:02.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","0","11","0.8741205930709839","0.9408482973057412","0.935416858167111","0.938463875171798","0.9347724836432189","0.9373753785719672","0.20895449700261773","1","0.0047949343919754025","1","1","1","1","1","197769","0.89324551820755","0.2539830362677425","0.9571214489437472","0.937926681394615","0.9508117432730908","0.9427321279196964","0.9471480003827872"
Loading

0 comments on commit 53b1046

Please sign in to comment.