diff --git a/.github/workflows/python_test.yml b/.github/workflows/python_test.yml index d4a0617..93d8f9c 100644 --- a/.github/workflows/python_test.yml +++ b/.github/workflows/python_test.yml @@ -47,4 +47,5 @@ jobs: uses: snickerbockers/submodules-init@v4 - name: Test with unittest run: | + export MIHGNN_UNITTEST_SKIP_GITHUB_ACTION_CRASHERS=True python -m unittest discover tests/ -v diff --git a/README.md b/README.md index 7900882..64706dd 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # MI-HGNN for contact estimation/classification -This repository implements a Morphologically-inspired Heterogeneous graph neural network for estimating contact information on the feet of a quadruped robot. +This repository implements a Morphology-Inspired Heterogeneous Graph Neural Network (MI-HGNN) for estimating contact information on the feet of a quadruped robot. ## Installation To get started, setup a Conda Python environment with Python=3.11: @@ -13,7 +13,7 @@ Then, install the library (and dependencies) with the following command: pip install . ``` -Note, if you have any issues with setup, refer to the `environment_files/README.md` so you can install the exact libraries we used. +Note, if you have any issues with setup, refer to `environment_files/README.md` so you can install the exact libraries we used. ## URDF Download The necessary URDF files are part of git submodules in this repository, so run the following commands to download them: @@ -31,29 +31,25 @@ in the `train_model` function. The model weights will be saved in the following folder, based on the model type and the randomly chosen model name (which is output in the terminal when training begins): ``` -/models/-/ +/models// ``` There will be the six models saved, one with the final model weights, and five with the best validation losses during training. -### LinTzuYaun Contact Dataset Model +### Contact Detection (Classification) Experiment -To train a model from the dataset from [Legged Robot State Estimation using Invariant Kalman Filtering and Learned Contact Events](https://arxiv.org/abs/2106.15713), run the following command within your Conda environment: +To train a model from the dataset from [MorphoSymm-Replication]([https://arxiv.org/abs/2106.15713](https://github.com/lunarlab-gatech/MorphoSymm-Replication/releases/tag/RepBugFixes)), run the following command within your Conda environment. Feel free to edit the model parameters within the file itself: ``` python research/train_classification.py ``` -If you want to customize the model used, the number of layers, or the hidden size, feel free to change the corresponding variables. - -To evaluate this model, edit `evaluator_classification.py` to specify which model to evaluate, its type, and the number of dataset entries to consider. Then, run the following command: +To evaluate a model, edit `evaluator_classification.py` to specify which model to evaluate. Then, run the following command: ``` python research/evaluator_classification.py ``` -The visualization of the predicted and GT values will be found in a file called `model_eval_results.pdf` in the same directory as your model weights. - -### QuadSDK & Real World GRF Model +### GRF (Regression) Model Not Yet Implemented. @@ -61,16 +57,9 @@ Not Yet Implemented. Tutorial not yet written. - -## Changing the model type -Currently, two model types are supported: -- `mlp` -- `heterogeneous_gnn` -To change the model type, please change the `model_type` parameter in the `train.py` and `evaluator.py` files. - ## Editing this repository -If you want to make changes to the source files, feel free to edit them in the `src/grfgnn` folder, and then +If you want to make changes to the source files, feel free to edit them in the `src/mi_hgnn/' folder, and then rebuild the library following the instructions in [#Installation](#installation). ## Paper Replication -To replicate our paper results with the model weights we trained, see `paper/README.md`. \ No newline at end of file +To replicate our paper results with the model weights we trained, see `paper/README.md`. diff --git a/paper/README.md b/paper/README.md index 111eb95..8a11678 100644 --- a/paper/README.md +++ b/paper/README.md @@ -1,41 +1,45 @@ # Paper Replication -This directory provides the model weights for all of our MI-HGNN models (and MLP models) referenced in the paper. Whenever a specific trained model is referenced in this README (for example, `ancient-salad-5`), it will be highlighted as shown, and there will be a folder on [Google Drive](https://drive.google.com/drive/folders/1NS5H_JIXW-pORyQUR15-t3mzf2byG26v?usp=sharing) with its name. In that folder will be its weights after the full 30 epochs of training, which were used to generate the paper results. +This directory provides the model weights for all of our MI-HGNN models referenced in the paper. Whenever a specific trained model is referenced in this README (for example, `ancient-salad-5`), it will be highlighted as shown, and there will be a folder on Georgia Tech's [Dropbox](https://www.dropbox.com/scl/fo/8p165xcfbdfwlcr3jx7tb/ABoxs5BOEXsQnJgFXF_Mjcc?rlkey=znrs7oyu29qsswpd3a5r55zk8&st=53v30ys3&dl=0) with its name. Unless otherwise specified, the model weights used for the paper were those trained the longest (have highest `epoch=` number in their .ckpt file). To find the name of a specific model referenced in the paper or to replicate the results, refer to the following sections below which correspond to paper sections. ## Contact Detection (Classification) Experiment -To replicate the results of these experiments on your own end, input the checkpoint path into the `evaluator_classification.py` file found in the `research` directory of this repository. +Our models trained during this experiment can be found in the table below. For more details, see the `contact_experiment.csv` file in this directory. To evaluate the model metrics on your own end, input the checkpoint path into the `evaluator_classification.py` file found in the `research` directory of this repository: -As of the time of this commit, the main experiment has not been completed yet. +| Number of Layers | Hidden Sizes | Seed | State Accuracy (Test) | Model Name | +| ---------------- | ------------ | ---- |---------------------- | --------------------- | +| 8 | 128 | 0 | 0.874120593070984 | `gentle-morning-4` | +| 8 | 128 | 1 | 0.895811080932617 | `leafy-totem-5` | +| 8 | 128 | 2 | 0.868574500083923 | `different-oath-6` | +| 8 | 128 | 3 | 0.878039181232452 | `hopeful-mountain-7` | +| 8 | 128 | 4 | 0.855807065963745 | `revived-durian-8` | +| 8 | 128 | 5 | 0.875732064247131 | `robust-planet-9` | +| 8 | 128 | 6 | 0.883218884468079 | `super-microwave-10` | +| 8 | 128 | 7 | 0.880922436714172 | `valiant-dawn-11` | + +The baseline models we compared to (ECNN, CNN-aug, CNN) were trained on this release: [MorphoSymm-Replication -> With Bug Fixes](https://github.com/lunarlab-gatech/MorphoSymm-Replication/releases/tag/RepBugFixes). See that repository for information on accessing those model weights, and replicating the Contact Detection Experiment Figure seen in our paper. ### Abalation Study -For this paper, we conducted an abalation study to see how parameter-efficient our model is. In the paper, we give the each trained model's layer number, hidden size, parameter number, and finally, the state accuracy on the test set. The table below relates these parameters to specific trained model names so that you can find the exact checkpoint weights for each model. - -| Number of Layers | Hidden Sizes | Number of Parameters | State Accuracy (Test) | Model Name | -| ---------------- | ------------ | ---------------------| --------------------- | --------------------- | -| 4 | 5 | | | | -| 4 | 10 | | | | -| 4 | 25 | | | | -| 4 | 50 | | | | -| 4 | 128 | | | | -| 8 | 50 | | | | -| 8 | 128 | | | | -| 8 | 256 | | | | -| 12 | 50 | | | | -| 12 | 128 | | | | -| 12 | 256 | | | | - -### Side note on normalization - -In our paper, we mention that we found that our MI-HGNN model performed better without entry-wise normalization. We found this by running the two models below. This wasn't an exhausive experiment, which is why it only deserved a short reference in the paper. However, you can see that for this specific configuration of layer size and hidden size, our MI-HGNN model has a _______ increase in accuracy when disabling the entry-wise normalization used in this experiment. - -| Number of Layers | Hidden Sizes | Normalization | State Accuracy (Test) | Model Name | -| ---------------- | ------------ | --------------| --------------------- | ----------------- | -| 12 | 128 | False | | | -| 12 | 128 | True | | | +We conducted an abalation study to see how parameter-efficient our model is. In the paper, we give the each trained model's layer number, hidden size, parameter number, and finally, the state accuracy on the test set. Here those values are associated with the model's name in the table below. For more details, see the `contact_experiment_ablation.csv` file in this directory. + +| Number of Layers | Hidden Sizes | Model Name | +| ---------------- | ------------ | ---------------------- | +| 4 | 5 | `prime-water-16` | +| 4 | 10 | `driven-shape-17` | +| 4 | 25 | `autumn-terrain-18` | +| 4 | 50 | `comfy-dawn-19` | +| 4 | 128 | `prime-butterfly-20` | +| 4 | 256 | `youthful-galaxy-21` | +| 8 | 50 | `exalted-mountain-22` | +| 8 | 128 | `serene-armadillo-23` | +| 8 | 256 | `playful-durian-12` | +| 12 | 50 | `twilight-armadillo-15`| +| 12 | 128 | `sparkling-music-14` | +| 12 | 256 | `stoic-mountain-13` | + ## Ground Reaction Force Estimation (Regression) Experiment diff --git a/paper/contact_experiment.csv b/paper/contact_experiment.csv new file mode 100644 index 0000000..fe38957 --- /dev/null +++ b/paper/contact_experiment.csv @@ -0,0 +1,9 @@ +"Name","State","Notes","User","Tags","Created","GPU Type","Runtime","Sweep","ID","Group","Job Type","Updated","End Time","Hostname","Description","Commit","GitHub","GPU Count","batch_size","hidden_channels","lr","normalize","num_layers","num_parameters","optimizer","regression","seed","epoch","test_Accuracy","test_F1_Score_Leg_0","test_F1_Score_Leg_1","test_F1_Score_Leg_2","test_F1_Score_Leg_3","test_F1_Score_Leg_Avg","test_CE_loss","train_Accuracy","train_CE_loss","train_F1_Score_Leg_0","train_F1_Score_Leg_1","train_F1_Score_Leg_2","train_F1_Score_Leg_3","train_F1_Score_Leg_Avg","trainer/global_step","val_Accuracy","val_CE_loss","val_F1_Score_Leg_0","val_F1_Score_Leg_1","val_F1_Score_Leg_2","val_F1_Score_Leg_3","val_F1_Score_Leg_Avg" +"valiant-dawn-11","finished","-","","","2024-09-01T02:45:31.000Z","NVIDIA GeForce RTX 4090","7921","","qc5v7hq2","","","2024-09-01T04:57:32.000Z","2024-09-01T04:57:32.000Z","aellnar08106d","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/state-estimation-gnn/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","1","30","128","0.0001","true","8","1585282","adam","false","7","11","0.8809224367141724","0.9493856530476056","0.935029145210738","0.9372308446878636","0.9310457129322596","0.9381728389696168","0.24830123724984204","0.9666666388511658","0.032949131727218625","1","1","0.9090909090909092","1","0.9772727272727272","197769","0.892520546913147","0.34584198461352583","0.9549937648792652","0.9352832408808204","0.9484526033628464","0.9479270099132387","0.9466641547590428" +"super-microwave-10","finished","-","","","2024-09-01T02:42:58.000Z","NVIDIA GeForce RTX 4090 Laptop GPU","12255","","m3h353sj","","","2024-09-01T06:07:13.000Z","2024-09-01T06:07:13.000Z","aellnar08116l","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/state-estimation-gnn/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","1","30","128","0.0001","true","8","1585282","adam","false","6","11","0.8832188844680786","0.9426359403779891","0.9379542360160368","0.9349237797511736","0.9363857798233686","0.937974933992142","0.22593691841093608","1","0.0007627168670296669","1","1","1","1","1","197769","0.8945062756538391","0.25110076049185237","0.9582996866823864","0.9367402297310874","0.9489920041520022","0.945436877479958","0.9473671995113584" +"robust-planet-9","finished","-","","","2024-09-01T02:40:46.000Z","Tesla V100-PCIE-32GB","57218","","x03al74t","","","2024-09-01T18:34:24.000Z","2024-09-01T18:34:24.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","5","11","0.8757320642471313","0.9368620352139556","0.9395511729670596","0.9243936487081676","0.92408517892659","0.9312230089539432","0.3994351616966936","1","0.014091232419013977","1","1","1","1","1","197769","0.8953362703323364","0.32372366120207813","0.9586080645422588","0.9393461331901182","0.9487649415313752","0.9438173089399248","0.9476341120509192" +"revived-durian-8","finished","-","","","2024-09-01T02:40:04.000Z","Tesla V100-PCIE-32GB","57065","","0gci3y6k","","","2024-09-01T18:31:09.000Z","2024-09-01T18:31:09.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","4","11","0.8558070659637451","0.9230507735568092","0.9265147810375658","0.9244070560332048","0.9251778831900384","0.9247876234544046","0.3132841911315715","0.9666666388511658","0.020360757907231648","1","1","1","0.9473684210526316","0.986842105263158","197769","0.894138514995575","0.30622858110468476","0.9565076484314232","0.9368660105980318","0.9490269971735812","0.9439106725525838","0.9465778321889048" +"hopeful-mountain-7","finished","-","","","2024-09-01T02:39:42.000Z","Tesla V100-PCIE-32GB","57201","","uwkkryce","","","2024-09-01T18:33:03.000Z","2024-09-01T18:33:03.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","3","11","0.8780391812324524","0.9483838560653916","0.9340515565656868","0.9356194293234432","0.93041608228144","0.9371177310589904","0.29722806023702947","1","0.005201342205206553","1","1","1","1","1","197769","0.8960296511650085","0.32580944874108275","0.9559902594784628","0.9396684786255955","0.947300166966558","0.948313878080415","0.9478181957877578" +"different-oath-6","finished","-","","","2024-09-01T02:38:59.000Z","Tesla V100-PCIE-32GB","57209","","uog66inh","","","2024-09-01T18:32:28.000Z","2024-09-01T18:32:28.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","2","11","0.8685745000839233","0.9349225572056856","0.9279084972232032","0.9274321664688614","0.9280777747127807","0.9295852489026328","0.30687451548072275","1","0.00036975229158997536","1","1","1","1","1","197769","0.8981624245643616","0.2890206184466963","0.955764530696146","0.9395106365088708","0.951623159576723","0.9454474543770416","0.9480864452896952" +"leafy-totem-5","finished","-","","","2024-09-01T02:37:50.000Z","Tesla V100-PCIE-32GB","57376","","hwfmo10l","","","2024-09-01T18:34:06.000Z","2024-09-01T18:34:06.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","1","11","0.8958110809326172","0.948565573004125","0.9538250159477892","0.9479477870050814","0.946401498925348","0.949184968720586","0.2155239851976042","1","0.0006038505584001541","1","1","1","1","1","197769","0.8936552405357361","0.3025799200995997","0.957629849207372","0.9373885509540872","0.9510802245708864","0.9446412469406156","0.9476849679182404" +"gentle-morning-4","finished","-","","","2024-09-01T02:36:40.000Z","Tesla V100-PCIE-32GB","57082","","aegbo1u4","","","2024-09-01T18:28:02.000Z","2024-09-01T18:28:02.000Z","aelcsml19004g","","b5163bd6d13deec3dfe5f450a51c14137be0bfd4","https://github.com/lunarlab-gatech/Morphology-Informed-HGNN/tree/b5163bd6d13deec3dfe5f450a51c14137be0bfd4","8","30","128","0.0001","true","8","1585282","adam","false","0","11","0.8741205930709839","0.9408482973057412","0.935416858167111","0.938463875171798","0.9347724836432189","0.9373753785719672","0.20895449700261773","1","0.0047949343919754025","1","1","1","1","1","197769","0.89324551820755","0.2539830362677425","0.9571214489437472","0.937926681394615","0.9508117432730908","0.9427321279196964","0.9471480003827872" \ No newline at end of file diff --git a/paper/contact_experiment_ablation.csv b/paper/contact_experiment_ablation.csv new file mode 100644 index 0000000..c0161e0 --- /dev/null +++ b/paper/contact_experiment_ablation.csv @@ -0,0 +1,13 @@ +"Name","State","Notes","User","Tags","Created","GPU Type","Runtime","Sweep","batch_size","hidden_channels","lr","normalize","num_layers","num_parameters","optimizer","regression","seed","epoch","test_Accuracy","test_CE_loss","test_F1_Score_Leg_0","test_F1_Score_Leg_1","test_F1_Score_Leg_2","test_F1_Score_Leg_3","test_F1_Score_Leg_Avg","train_Accuracy","train_CE_loss","train_F1_Score_Leg_0","train_F1_Score_Leg_1","train_F1_Score_Leg_2","train_F1_Score_Leg_3","train_F1_Score_Leg_Avg","trainer/global_step","val_Accuracy","val_CE_loss","val_F1_Score_Leg_0","val_F1_Score_Leg_1","val_F1_Score_Leg_2","val_F1_Score_Leg_3","val_F1_Score_Leg_Avg" +"comfy-dawn-19","finished","-","","","2024-09-01T18:48:30.000Z","Tesla V100-PCIE-32GB","35124","","30","50","0.0001","true","4","206252","adam","false","999","12","0.8534150123596191","0.2134763519747147","0.929184824638887","0.9283932523556798","0.9326889298933924","0.9317773459656672","0.9305110882134068","0.9666666388511658","0.08851965268452962","0.9473684210526316","1","1","0.9523809523809524","0.974937343358396","215748","0.8486147522926331","0.25754772896507705","0.9358792042821762","0.9185583001150824","0.9282934251621686","0.9199955277995177","0.9256816143397364" +"youthful-galaxy-21","finished","-","","","2024-09-01T18:49:13.000Z","Tesla V100-PCIE-32GB","32926","","30","256","0.0001","true","4","3165442","adam","false","999","11","0.858852207660675","0.41298029632668015","0.921647137694906","0.937454355764215","0.9153995305643432","0.9282226724681968","0.9256809241229152","1","0.001223126177986463","1","1","1","1","1","197769","0.8672423958778381","0.3783110435643669","0.9450396953462852","0.920413143029047","0.932141419022308","0.9321393998063892","0.9324334143010072" +"autumn-terrain-18","finished","-","","","2024-09-01T18:48:07.000Z","Tesla V100-PCIE-32GB","35261","","30","25","0.0001","true","4","78127","adam","false","999","12","0.8766559958457947","0.1315367893661957","0.935669499062096","0.9433902233938596","0.940950435884344","0.945707346051751","0.9414293760980126","0.9333333373069764","0.02446378469467163","0.96","0.9565217391304348","1","1","0.9791304347826087","215748","0.8518611788749695","0.1799276879379813","0.9401999805844092","0.9147268448763312","0.925213257797747","0.9243566415388615","0.9261241811993371" +"prime-water-16","finished","-","","","2024-09-01T18:47:41.000Z","Tesla V100-PCIE-32GB","83644","","30","5","0.0001","true","4","11627","adam","false","999","43","0.8543043732643127","0.15391518125003317","0.9291618294947246","0.9147546559466216","0.9257299856779588","0.9196665637542452","0.9223282587183876","0.8333333134651184","0.15197056134541828","1","0.9230769230769232","0.9166666666666666","0.9","0.9349358974358974","773097","0.8329183459281921","0.16214891726027253","0.9219262645581628","0.9112338494920604","0.918009612666101","0.917742623702135","0.9172280876046148" +"driven-shape-17","finished","-","","","2024-09-01T18:47:56.000Z","Tesla V100-PCIE-32GB","40340","","30","10","0.0001","true","4","25252","adam","false","999","15","0.8736161589622498","0.11525020069385816","0.9406982110879316","0.9351934696402338","0.9446415556451704","0.9381137363221858","0.9396617431738804","0.8999999761581421","0.06631178855895996","0.96","0.9523809523809524","0.9473684210526316","0.9565217391304348","0.9540677781410049","269685","0.8355448842048645","0.15830966543348987","0.9266773615174106","0.9145911557821904","0.916287347813076","0.920092012374078","0.9194119693716888" +"prime-butterfly-20","finished","-","","","2024-09-01T18:48:48.000Z","Tesla V100-PCIE-32GB","32930","","30","128","0.0001","true","4","927362","adam","false","999","11","0.8603283762931824","0.25785508210942437","0.9395386976408282","0.933949009216504","0.9214418861857694","0.9301133577976198","0.9312607377101804","1","0.004377327859401703","1","1","1","1","1","197769","0.8645107746124268","0.2967396493260209","0.9428405347800816","0.9207676999359126","0.932531406121162","0.9324021439012506","0.9321354461846016" +"serene-armadillo-23","finished","-","","","2024-09-02T02:26:58.000Z","NVIDIA GeForce RTX 4090 Laptop GPU","11750","","30","128","0.0001","true","8","1585282","adam","false","999","11","0.8795312643051147","0.3193329870347951","0.9354252438707764","0.9491536934883824","0.9291468750344402","0.9379047392978214","0.937907637922855","1","0.008158960441748301","1","1","1","1","1","197769","0.8907240033149719","0.3203825146300207","0.9567497145062968","0.93921748560941","0.9492765299216912","0.9414689700527726","0.9466781750225428" +"exalted-mountain-22","finished","-","","","2024-09-01T22:49:43.000Z","NVIDIA GeForce RTX 4090 Laptop GPU","11836","","30","50","0.0001","true","8","307252","adam","false","999","11","0.8505211472511292","0.31005512734228563","0.928144780244272","0.9190371187063578","0.9254917866558507","0.9133282169352775","0.9215004756354396","1","0.002146148681640625","1","1","1","1","1","197769","0.8907660245895386","0.244682212376945","0.9558234229576008","0.935202398298414","0.9468309575581864","0.945737947641594","0.9458986816139489" +"playful-durian-12","finished","-","","","2024-09-01T13:26:34.000Z","NVIDIA GeForce RTX 4090 Laptop GPU","12930","","30","256","0.0001","true","8","5792002","adam","false","999","11","0.8934270143508911","0.21581461379831257","0.9486826471531222","0.9398924662882934","0.9468395753155932","0.9435691288407496","0.9447459543994396","0.9666666388511658","0.01915186643600464","1","1","0.9473684210526316","1","0.986842105263158","197769","0.8963343501091003","0.4257049562150009","0.9561747860222666","0.9402390438247012","0.9509868579793138","0.9463661409076994","0.9484417071834952" +"sparkling-music-14","finished","-","","","2024-09-01T17:39:43.000Z","NVIDIA GeForce RTX 4090 Laptop GPU","16059","","30","128","0.0001","true","12","2243202","adam","false","999","11","0.8833224773406982","0.3085127340626344","0.936454052766363","0.9456326983484395","0.9331364441819258","0.9409620125337848","0.9390463019576284","1","0.000025788651934514444","1","1","1","1","1","197769","0.8991605639457703","0.2738865827016421","0.9583333333333331","0.9377074462735104","0.9524659312134977","0.9495618998351062","0.9495171526638618" +"twilight-armadillo-15","finished","-","","","2024-09-01T18:45:21.000Z","NVIDIA GeForce RTX 4090","10161","","30","50","0.0001","true","12","408252","adam","false","999","11","0.878862202167511","0.22975379368638196","0.9394501802232322","0.9400805639476334","0.9330187298800116","0.9380604413762714","0.9376524788567872","1","0.0012825908760229743","1","1","1","1","1","197769","0.8920267820358276","0.2273940571122388","0.9519974226804124","0.9385472995527072","0.9479886028879122","0.9479143033092406","0.946611907107568" +"stoic-mountain-13","finished","-","","","2024-09-01T13:33:19.000Z","NVIDIA GeForce RTX 4090","12651","","30","256","0.0001","true","12","8418562","adam","false","999","11","0.9011660218238832","0.2178048762882903","0.9516031457955234","0.9508494779407326","0.9476071893408768","0.9475353160271306","0.9493987822760658","0.9666666388511658","0.0061408405502637224","1","1","0.9714285714285714","1","0.9928571428571428","197769","0.8985826969146729","0.3934961333538206","0.9566612121899454","0.937063409037222","0.9509143407122234","0.9502988618685008","0.9487344559519728" \ No newline at end of file diff --git a/research/evaluator_classification.py b/research/evaluator_classification.py index a1bac5a..73b9bab 100644 --- a/research/evaluator_classification.py +++ b/research/evaluator_classification.py @@ -12,39 +12,46 @@ def main(): # ================================= CHANGE THESE =================================== - path_to_checkpoint = "/home/dbutterfield3/Research/state-estimation-gnn/models/ancient-salad-5/epoch=29-val_CE_loss=0.34249.ckpt" - model_type = 'heterogeneous_gnn' - num_entries_to_eval = 100 + path_to_checkpoint = None + starting_index_for_chart = 10000 + num_to_visualize_in_chart = 1000 # ================================================================================== + # Check that the user filled in the necessary parameters + if path_to_checkpoint is None: + raise ValueError("Please provide a checkpoint path by editing this file!") + # Set parameters history_length = 150 + model_type = 'heterogeneous_gnn' path_to_urdf = Path('urdf_files', 'MiniCheetah', 'miniCheetah.urdf').absolute() # Initialize the Testing datasets air_jumping_gait = linData.LinTzuYaunDataset_air_jumping_gait( - Path(Path('.').parent, 'datasets', 'LinTzuYaun-AJG').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=False) + Path(Path('.').parent, 'datasets', 'LinTzuYaun-AJG').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) concrete_pronking = linData.LinTzuYaunDataset_concrete_pronking( - Path(Path('.').parent, 'datasets', 'LinTzuYaun-CP').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=False) + Path(Path('.').parent, 'datasets', 'LinTzuYaun-CP').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) concrete_right_circle = linData.LinTzuYaunDataset_concrete_right_circle( - Path(Path('.').parent, 'datasets', 'LinTzuYaun-CRC').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=False) + Path(Path('.').parent, 'datasets', 'LinTzuYaun-CRC').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) forest = linData.LinTzuYaunDataset_forest( - Path(Path('.').parent, 'datasets', 'LinTzuYaun-F').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=False) + Path(Path('.').parent, 'datasets', 'LinTzuYaun-F').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) small_pebble = linData.LinTzuYaunDataset_small_pebble( - Path(Path('.').parent, 'datasets', 'LinTzuYaun-SP').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=False) + Path(Path('.').parent, 'datasets', 'LinTzuYaun-SP').absolute(), path_to_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) test_dataset = torch.utils.data.ConcatDataset([air_jumping_gait, concrete_pronking, concrete_right_circle, forest, small_pebble]) # Convert them to subsets test_dataset = torch.utils.data.Subset(test_dataset, np.arange(0, test_dataset.__len__())) # Evaluate with model - pred, labels = evaluate_model(path_to_checkpoint, test_dataset, num_entries_to_eval) - - # Output the corresponding results - metric_acc = torchmetrics.Accuracy(task="multiclass", num_classes=16) - y_pred_16, y_16 = Base_Lightning.classification_conversion_16_class(None, pred, labels) - print("Accuracy: ", metric_acc(torch.argmax(y_pred_16, dim=1), y_16.squeeze())) - visualize_model_outputs_classification(pred, labels, str(path_to_checkpoint) + ".pdf", 100) + pred, labels, acc, f1_leg_0, f1_leg_1, f1_leg_2, f1_leg_3, f1_avg_legs = evaluate_model(path_to_checkpoint, test_dataset) + + # Print the results + print("Model Accuracy: ", acc) + print("F1-Score Leg 0: ", f1_leg_0) + print("F1-Score Leg 1: ", f1_leg_1) + print("F1-Score Leg 2: ", f1_leg_2) + print("F1-Score Leg 3: ", f1_leg_3) + print("F1-Score Legs Avg: ", f1_avg_legs) if __name__ == "__main__": main() diff --git a/research/train_classification.py b/research/train_classification.py index 1fddcaf..22e61cc 100644 --- a/research/train_classification.py +++ b/research/train_classification.py @@ -712,8 +712,8 @@ def main(): # Train the model train_model(train_dataset, val_dataset, test_dataset, normalize, num_layers=num_layers, hidden_size=hidden_size, - logger_project_name="mi_hgnn_class_abalation", batch_size=30, regression=False, lr=0.0001, epochs=30, seed=0, - devices=1) + logger_project_name="mi_hgnn_class_early_stopping", batch_size=30, regression=False, lr=0.0001, epochs=49, + seed=0, devices=1, early_stopping=True) if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/mi_hgnn/lightning_py/gnnLightning.py b/src/mi_hgnn/lightning_py/gnnLightning.py index 7fd592f..54ddef3 100644 --- a/src/mi_hgnn/lightning_py/gnnLightning.py +++ b/src/mi_hgnn/lightning_py/gnnLightning.py @@ -5,6 +5,7 @@ from lightning.pytorch.loggers import WandbLogger from torch_geometric.loader import DataLoader from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks.early_stopping import EarlyStopping from pathlib import Path import names from torch.utils.data import Subset @@ -130,7 +131,8 @@ def calculate_losses_step(self, y: torch.Tensor, y_pred: torch.Tensor): # Calculate 16 class accuracy y_pred_16, y_16 = self.classification_conversion_16_class(y_pred_per_foot_prob_only_1, y) - self.acc = self.metric_acc(torch.argmax(y_pred_16, dim=1), y_16.squeeze()) + y_16 = y_16.squeeze(dim=1) + self.acc = self.metric_acc(torch.argmax(y_pred_16, dim=1), y_16) # Calculate binary class predictions for f1-scores y_pred_2 = torch.reshape(torch.argmax(y_pred_per_foot_prob, dim=1), (batch_size, 4)) @@ -206,23 +208,36 @@ def on_test_epoch_end(self): self.log_losses("test", on_step=False) # ======================= Prediction ======================= + # NOTE: These methods have not been fully tested. Use at + # your own risk. + + def on_predict_start(self): + self.reset_all_metrics() + def predict_step(self, batch, batch_idx): """ Returns the predicted values from the model given a specific batch. Returns: - y (torch.Tensor) - Ground Truth labels per foot (contact labels - for classifiction, GRF labels for regression) - y_pred (torch.Tensor) - Predicted outputs per foot (GRF labels - for regression, probability of stable contact for classification) + y (torch.Tensor) - Ground Truth labels per foot (GRF labels for + regression, 16 class contact labels for classifiction) + y_pred (torch.Tensor) - Predicted outputs (GRF labels per foot + for regression, 16 class predictions for classifications) """ y, y_pred = self.step_helper_function(batch) + self.calculate_losses_step(y, y_pred) + if self.regression: - return y, y_pred + return y, y_pred # GRFs else: y_pred_per_foot, y_pred_per_foot_prob, y_pred_per_foot_prob_only_1 = \ self.classification_calculate_useful_values(y_pred, y_pred.shape[0]) - return y, y_pred_per_foot_prob_only_1 + y_pred_16, y_16 = self.classification_conversion_16_class(y_pred_per_foot_prob_only_1, y) + y_16 = y_16.squeeze(dim=1) + return y_16, torch.argmax(y_pred_16, dim=1) # 16 class + + def on_predict_end(self): + self.calculate_losses_epoch() # ======================= Optimizer ======================= def configure_optimizers(self): @@ -406,14 +421,19 @@ def step_helper_function(self, batch): edge_index_dict=batch.edge_index_dict) # Get the outputs from the foot nodes - y_pred = torch.reshape(out_raw.squeeze(), (batch.batch_size, self.model.out_channels_per_foot * 4)) + batch_size = None + if hasattr(batch, "batch_size"): + batch_size = batch.batch_size + else: + batch_size = 1 + y_pred = torch.reshape(out_raw.squeeze(), (batch_size, self.model.out_channels_per_foot * 4)) # Get the labels - y = torch.reshape(batch.y, (batch.batch_size, 4)) + y = torch.reshape(batch.y, (batch_size, 4)) return y, y_pred -def evaluate_model(path_to_checkpoint: Path, predict_dataset: Subset, num_entries_to_eval: int = 1000): +def evaluate_model(path_to_checkpoint: Path, predict_dataset: Subset): """ Runs the provided model on the corresponding dataset, and returns the predicted values and the ground truth values. @@ -421,6 +441,8 @@ def evaluate_model(path_to_checkpoint: Path, predict_dataset: Subset, num_entrie Returns: pred - Predicted values labels - Ground Truth values + * - Additional arguments that correspond to the metrics tracked + during the evaluation. """ # Set the dtype to be 64 by default @@ -441,43 +463,60 @@ def evaluate_model(path_to_checkpoint: Path, predict_dataset: Subset, num_entrie model = Heterogeneous_GNN_Lightning.load_from_checkpoint(str(path_to_checkpoint)) else: raise ValueError("model_type must be mlp or heterogeneous_gnn.") + model.eval() + model.freeze() # Create a validation dataloader valLoader: DataLoader = DataLoader(predict_dataset, batch_size=100, shuffle=False, num_workers=15) # Predict with the model - pred = torch.zeros((0, 4)) - labels = torch.zeros((0, 4)) + pred = None + labels = None if model_type == 'mlp': - trainer = L.Trainer() - predictions_result = trainer.predict(model, valLoader) - for batch_result in predictions_result: - labels = torch.cat((labels, batch_result[0]), dim=0) - pred = torch.cat((pred, batch_result[1]), dim=0) - + raise NotImplementedError + else: # for 'heterogeneous_gnn' + pred = torch.zeros((0)) + labels = torch.zeros((0)) device = 'cpu' # 'cuda' if torch.cuda.is_available() else model.model = model.model.to(device) with torch.no_grad(): + # Print visual output of prediction step + total_batches = len(valLoader) + batch_num = 0 + print("Prediction: ", batch_num, "/", total_batches, "\r", end="") + + # Predict with the model for batch in valLoader: labels_batch, y_pred = model.step_helper_function(batch) + model.calculate_losses_step(labels_batch, y_pred) - # If classification, convert probability logits to stable contact probabilities + # If classification, convert to 16 class predictions and labels if not model.model.regression: y_pred_per_foot, y_pred_per_foot_prob, y_pred_per_foot_prob_only_1 = \ model.classification_calculate_useful_values(y_pred, y_pred.shape[0]) - pred_batch = y_pred_per_foot_prob_only_1 + y_pred_16, y_16 = model.classification_conversion_16_class(y_pred_per_foot_prob_only_1, labels_batch) + y_16 = y_16.squeeze(dim=1) + pred_batch = torch.argmax(y_pred_16, dim=1) + labels_batch = y_16 else: pred_batch = y_pred # Append to the previously collected data pred = torch.cat((pred, pred_batch), dim=0) labels = torch.cat((labels, labels_batch), dim=0) - if pred.shape[0] >= num_entries_to_eval: - break - return pred[0:num_entries_to_eval], labels[0:num_entries_to_eval] + # Print current status + batch_num += 1 + print("Prediction: ", batch_num, "/", total_batches, "\r", end="") + model.calculate_losses_epoch() + + if not model.regression: + legs_avg_f1 = (model.f1_leg0 + model.f1_leg1 + model.f1_leg2 + model.f1_leg3) / 4.0 + return pred, labels, model.acc, model.f1_leg0, model.f1_leg1, model.f1_leg2, model.f1_leg3, legs_avg_f1 + else: + raise NotImplementedError def train_model( train_dataset: Subset, @@ -491,11 +530,12 @@ def train_model( num_layers: int = 8, optimizer: str = "adam", lr: float = 0.003, - epochs: int = 100, + epochs: int = 30, hidden_size: int = 10, regression: bool = True, seed: int = 0, - devices: int = 1): + devices: int = 1, + early_stopping: bool = False): """ Train a learning model with the input datasets. If 'testing_mode' is enabled, limit the batches and epoch size @@ -538,12 +578,14 @@ def train_model( limit_train_batches = None limit_val_batches = None limit_test_batches = None + limit_predict_batches = None num_workers = 30 persistent_workers = True if testing_mode: limit_train_batches = 10 limit_val_batches = 5 limit_test_batches = 5 + limit_predict_batches = limit_test_batches * batch_size num_workers = 1 persistent_workers = False @@ -652,6 +694,11 @@ def train_model( # Lower precision of operations for faster training torch.set_float32_matmul_precision("medium") + # Setup early stopping mechanism to match MorphoSymm-Replication + callbacks = [checkpoint_callback, last_model_callback] + if early_stopping: + callbacks.append(EarlyStopping(monitor=monitor, patience=10, mode='min')) + # Train the model and test trainer = L.Trainer( default_root_dir=path_to_save, @@ -663,12 +710,13 @@ def train_model( limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches, limit_test_batches=limit_test_batches, + limit_predict_batches=limit_predict_batches, check_val_every_n_epoch=1, enable_progress_bar=True, logger=wandb_logger, - callbacks=[checkpoint_callback, last_model_callback]) + callbacks=callbacks) trainer.fit(lightning_model, trainLoader, valLoader) - trainer.test(lightning_model, dataloaders=testLoader) + trainer.test(lightning_model, dataloaders=testLoader, verbose=True) # Return the path to the trained checkpoint return path_to_save diff --git a/tests/testGnnLightning.py b/tests/testGnnLightning.py index 9285af8..f2516d7 100644 --- a/tests/testGnnLightning.py +++ b/tests/testGnnLightning.py @@ -1,8 +1,8 @@ import unittest from pathlib import Path - import torchmetrics.classification from mi_hgnn import QuadSDKDataset_A1Speed1_0 +import mi_hgnn.datasets_py.LinTzuYaunDataset as linData from mi_hgnn.datasets_py.LinTzuYaunDataset import LinTzuYaunDataset_asphalt_road from mi_hgnn.lightning_py.gnnLightning import train_model, evaluate_model, Heterogeneous_GNN_Lightning, Base_Lightning from mi_hgnn.visualization import visualize_model_outputs_regression, visualize_model_outputs_classification @@ -11,11 +11,14 @@ import numpy as np import torchmetrics from torch_geometric.loader import DataLoader +import os class TestGnnLightning(unittest.TestCase): """ Test the classes and functions found in the gnnLightning.py file. + + TODO: Write test methods for visualization. """ def setUp(self): @@ -28,7 +31,7 @@ def setUp(self): # Setup a random generator self.rand_gen = torch.Generator().manual_seed(10341885) - # Initalize the datasets + # Initalize the simple datasets path_to_urdf = Path('urdf_files', 'A1', 'a1.urdf').absolute() self.path_to_mc_urdf = Path('urdf_files', 'MiniCheetah', 'miniCheetah.urdf').absolute() @@ -58,54 +61,72 @@ def setUp(self): self.path_to_crc_seq, self.path_to_mc_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', 'heterogeneous_gnn', 3) - + # Put them into an array for easy testing over all of them - self.models = [self.dataset_mlp, self.dataset_hgnn, self.dataset_hgnn_3] - self.class_models = [self.class_dataset_mlp, self.class_dataset_hgnn, self.class_dataset_hgnn_3] - - def test_train_eval_vis_model(self): + self.models = [self.dataset_hgnn, self.dataset_hgnn_3] + self.class_models = [self.class_dataset_hgnn, self.class_dataset_hgnn_3] + + # Create the test dataset for classification + history_length = 150 + model_type = 'heterogeneous_gnn' + air_jumping_gait = linData.LinTzuYaunDataset_air_jumping_gait( + Path(Path('.').parent, 'datasets', 'LinTzuYaun-AJG').absolute(), self.path_to_mc_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) + concrete_pronking = linData.LinTzuYaunDataset_concrete_pronking( + Path(Path('.').parent, 'datasets', 'LinTzuYaun-CP').absolute(), self.path_to_mc_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) + concrete_right_circle = linData.LinTzuYaunDataset_concrete_right_circle( + Path(Path('.').parent, 'datasets', 'LinTzuYaun-CRC').absolute(), self.path_to_mc_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) + forest = linData.LinTzuYaunDataset_forest( + Path(Path('.').parent, 'datasets', 'LinTzuYaun-F').absolute(), self.path_to_mc_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) + small_pebble = linData.LinTzuYaunDataset_small_pebble( + Path(Path('.').parent, 'datasets', 'LinTzuYaun-SP').absolute(), self.path_to_mc_urdf, 'package://yobotics_description/', 'mini-cheetah-gazebo-urdf/yobo_model/yobotics_description', model_type, history_length, normalize=True) + test_dataset = torch.utils.data.ConcatDataset([air_jumping_gait, concrete_pronking, concrete_right_circle, forest, small_pebble]) + self.test_dataset = torch.utils.data.Subset(test_dataset, np.arange(0, test_dataset.__len__())) + + def test_methods_run_without_crashing(self): """ Make sure that the train model function runs and finishes successfully. Also, test that we can evaluate - each model as well without a crash, and that we can - visualize the results. + each model as well without a crash. """ + # TODO: Reimplement Regression functionality for evaluation # For each regression model - for i, model in enumerate(self.models): - train_dataset, val_dataset, test_dataset = random_split( - model, [0.7, 0.2, 0.1], generator=self.rand_gen) - path_to_ckpt_folder = train_model(train_dataset, val_dataset, test_dataset, - normalize=False, - testing_mode=True, disable_logger=True, - epochs=2, seed=1919) - - # Make sure three models were saved (2 for top, 1 for last) - models = sorted(Path('.', path_to_ckpt_folder).glob(("epoch=*"))) - self.assertEqual(len(models), 3) - - try: - # Predict with the model - pred, labels = evaluate_model(models[0], test_dataset, 485) - - # Assert the sizes of the results match - self.assertEqual(pred.shape[0], 485) - self.assertEqual(pred.shape[1], 4) - self.assertEqual(labels.shape[0], 485) - self.assertEqual(labels.shape[1], 4) - - # Try and visualize with the model - visualize_model_outputs_regression(pred, labels) - - except Exception as e: - for path in models: - Path.unlink(path, missing_ok=False) - raise e + # for i, model in enumerate(self.models): + # train_dataset, val_dataset, test_dataset = random_split( + # model, [0.7, 0.2, 0.1], generator=self.rand_gen) + # path_to_ckpt_folder = train_model(train_dataset, val_dataset, test_dataset, + # normalize=False, + # testing_mode=True, disable_logger=True, + # epochs=2, seed=1919) + + # # Make sure three models were saved (2 for top, 1 for last) + # models = sorted(Path('.', path_to_ckpt_folder).glob(("epoch=*"))) + # self.assertEqual(len(models), 3) - for path in models: - Path.unlink(path, missing_ok=False) + # try: + # # Predict with the model (should fail) + # with self.assertRaises(NotImplementedError): + # pred, labels = evaluate_model(models[0], test_dataset) + + # # Assert the sizes of the results match + # self.assertEqual(pred.shape[0], 485) + # self.assertEqual(pred.shape[1], 4) + # self.assertEqual(labels.shape[0], 485) + # self.assertEqual(labels.shape[1], 4) + + # # Try and visualize with the model + # visualize_model_outputs_regression(pred, labels) + + # except Exception as e: + # for path in models: + # Path.unlink(path, missing_ok=False) + # raise e + + # for path in models: + # Path.unlink(path, missing_ok=False) # For each classification model + # TODO: Test for MLP for i, model in enumerate(self.class_models): # Test for classification train_dataset, val_dataset, test_dataset = random_split( @@ -120,17 +141,9 @@ def test_train_eval_vis_model(self): self.assertEqual(len(models), 3) try: - # Predict with the model - pred, labels = evaluate_model(models[0], test_dataset, 234) - - # Assert the sizes of the results match - self.assertEqual(pred.shape[0], 234) - self.assertEqual(pred.shape[1], 4) - self.assertEqual(labels.shape[0], 234) - self.assertEqual(labels.shape[1], 4) + # Evaluate the model + pred, labels, acc, f1_leg_0, f1_leg_1, f1_leg_2, f1_leg_3, f1_avg_legs = evaluate_model(models[0], test_dataset) - # Try to visualize the results - visualize_model_outputs_classification(pred, labels) except Exception as e: for path in models: Path.unlink(path, missing_ok=False) @@ -139,6 +152,34 @@ def test_train_eval_vis_model(self): for path in models: Path.unlink(path, missing_ok=False) + @unittest.skipIf(os.environ.get('MIHGNN_UNITTEST_SKIP_GITHUB_ACTION_CRASHERS') == "True", "Skipping tests that crash GitHub Actions") + def test_evaluate_model(self): + """ + Test that the evaluation method properly calculates + the same values logged during the training process. + + TODO: Test the MLP evaluation as well. + TODO: Test for Regression. + """ + + # Evaluate with model + path_to_checkpoint = Path('tests', 'test_models', 'epoch=10-val_CE_loss=0.30258.ckpt').absolute() + pred, labels, acc, f1_leg_0, f1_leg_1, f1_leg_2, f1_leg_3, f1_avg_legs = evaluate_model(str(path_to_checkpoint), self.test_dataset) + + # Assert that the evaluated metrics match what we calculated using + # the test methods during training + np.testing.assert_almost_equal(acc.item(), 0.89581, 4) + np.testing.assert_almost_equal(f1_leg_0.item(), 0.94857, 4) + np.testing.assert_almost_equal(f1_leg_1.item(), 0.95383, 4) + np.testing.assert_almost_equal(f1_leg_2.item(), 0.94795, 4) + np.testing.assert_almost_equal(f1_leg_3.item(), 0.9464, 4) + np.testing.assert_almost_equal(f1_avg_legs.item(), 0.94918, 4) + + # Assert the pred and labels are correct + metric_acc = torchmetrics.Accuracy(task="multiclass", num_classes=16) + calculated_acc = metric_acc(labels, pred) + np.testing.assert_equal(acc.item(), calculated_acc.item()) + def test_MIHGNN_model_output_assumption(self): """ The code as written assumes that the MIHGNN model output follows diff --git a/tests/test_models/epoch=10-val_CE_loss=0.30258.ckpt b/tests/test_models/epoch=10-val_CE_loss=0.30258.ckpt new file mode 100644 index 0000000..34878ef Binary files /dev/null and b/tests/test_models/epoch=10-val_CE_loss=0.30258.ckpt differ