Tom (Trend over Momentum) is an adaptive optimizer that accounts for the trend observed in the gradients across the loss landscape traversed by the neural network. Tom introduces an additional smoothing equation to capture this trend during optimization. The smoothing parameter introduced for the trend requires no tuning and can be used with its default value.
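To give an intuition for what a level-plus-trend recursion looks like, here is a minimal sketch in the style of Holt's double exponential smoothing. This is an illustration only, not Tom's exact update rule: the function name, variable names, and the constants `beta` and `phi` below are assumptions, so please refer to the paper for the actual equations.

```python
import torch

def smoothed_trend_step(grad, level, trend, beta=0.9, phi=0.9):
    """One Holt-style smoothing step: track a smoothed gradient (level)
    and the smoothed change between consecutive levels (trend)."""
    new_level = beta * (level + trend) + (1 - beta) * grad     # smoothed gradient estimate
    new_trend = phi * trend + (1 - phi) * (new_level - level)  # smoothed gradient trend
    return new_level, new_trend

# Gradients drifting upward produce a positive trend estimate.
level, trend = torch.tensor(0.0), torch.tensor(0.0)
for g in (1.0, 1.2, 1.4, 1.6):
    level, trend = smoothed_trend_step(torch.tensor(g), level, trend)
```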
Please check our paper for more details on Tom.
Prerequisites:
- PyTorch == 1.9.0
Please clone this repository to your local machine:

```sh
git clone https://github.com/AnirudhMaiya/Tom.git
```
You can import the optimizer and use it as a drop-in replacement in a standard PyTorch training loop:

```python
from optimizer.tom import Tom

network = YourNetwork()
opt = Tom(network.parameters())

for inputs, targets in loader:
    opt.zero_grad()                                     # clear gradients from the previous step
    loss = YourLossFunction(network(inputs), targets)   # forward pass, then loss(prediction, target)
    loss.backward()                                     # backpropagate
    opt.step()                                          # apply Tom's parameter update
```
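If you need non-default hyperparameters, Tom's constructor presumably accepts the usual torch.optim-style keyword arguments. This is an assumption based only on the `Tom(network.parameters())` call above; verify the actual signature in `optimizer/tom.py`.

```python
# Hypothetical: the keyword argument `lr` is an assumption, not confirmed
# by this README; check optimizer/tom.py for the real constructor signature.
opt = Tom(network.parameters(), lr=1e-3)
```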