
more VPG and DQN tweaks (wip)
JulioJerez committed Oct 16, 2023
1 parent 75e3695 commit 1171f2a
Showing 3 changed files with 22 additions and 10 deletions.
@@ -101,11 +101,12 @@ namespace ndCarpole_0
 		,m_timer(ndGetTimeInMicroseconds())
 		,m_maxGain(ndFloat32(- 1.0e10f))
 		,m_maxFrames(5000)
-		,m_stopTraining(2000000)
+		,m_stopTraining(4000000)
 		,m_modelIsTrained(false)
 	{
 		SetName("cartpoleVPG.dnn");
 		m_outFile = fopen("cartpole-VPG.csv", "wb");
+		fprintf(m_outFile, "VPG\n");
 	}
 #else
 	ndCartpoleAgentTrainer(ndBrainAgentDQN_Trainer<m_stateSize, m_actionsSize>::HyperParameters hyperParameters)
@@ -116,11 +117,10 @@ namespace ndCarpole_0
 		,m_maxGain(ndFloat32(-1.0e10f))
 		,m_maxFrames(5000)
 		,m_stopTraining(2000000)
-		,m_averageQvalue()
-		,m_averageFramesPerEpisodes()
 	{
 		SetName("cartpoleDQN.dnn");
 		m_outFile = fopen("cartpole-DQN.csv", "wb");
+		fprintf(m_outFile, "DQN\n");
 	}
 #endif

@@ -208,6 +208,7 @@ namespace ndCarpole_0
 		ndExpandTraceMessage("training complete\n\n");
 		ndUnsigned64 timer = ndGetTimeInMicroseconds() - m_timer;
 		ndExpandTraceMessage("training time: %f\n", ndFloat32(ndFloat64(timer) * ndFloat32(1.0e-6f)));
+		m_modelIsTrained = true;
 		if (m_outFile)
 		{
 			fclose(m_outFile);
@@ -443,7 +444,7 @@ namespace ndCarpole_0

 	BuildModel(model, scene, location);

-	//scene->SetAcceleratedUpdate();
+	scene->SetAcceleratedUpdate();
 	return model;
 }
 #endif
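Note: the fprintf calls added above write a one-word header row to each trainer's CSV, so the VPG and DQN score logs can be told apart when plotted later. A minimal sketch of that logging pattern, assuming nothing about the demo beyond what the diff shows (the ScoreLog class and WriteScore name are hypothetical, not part of the demo):

#include <cstdio>

// Hypothetical sketch: one CSV per trainer, a header line on open,
// then one score per row, matching the fopen/fprintf pattern above.
class ScoreLog
{
	public:
	ScoreLog(const char* fileName, const char* header)
		:m_file(fopen(fileName, "wb"))
	{
		if (m_file)
		{
			fprintf(m_file, "%s\n", header);
		}
	}

	~ScoreLog()
	{
		if (m_file)
		{
			fclose(m_file);
		}
	}

	void WriteScore(float episodeScore)
	{
		if (m_file)
		{
			fprintf(m_file, "%f\n", episodeScore);
		}
	}

	private:
	FILE* m_file;
};

Usage would mirror the constructors above, e.g. ScoreLog log("cartpole-VPG.csv", "VPG");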
21 changes: 16 additions & 5 deletions newton-4.00/sdk/dBrain/ndBrainAgentDQN_Trainer.h
@@ -73,7 +73,6 @@ class ndBrainAgentDQN_Trainer: public ndBrainAgent, public ndBrainThreadPool
 		ndInt32 m_hiddenLayersNumberOfNeurons;
 	};

-	//ndBrainAgentDQN_Trainer(const HyperParameters& hyperParameters, const ndBrain& actor);
 	ndBrainAgentDQN_Trainer(const HyperParameters& hyperParameters);
 	virtual ~ndBrainAgentDQN_Trainer();

@@ -113,7 +112,6 @@ class ndBrainAgentDQN_Trainer: public ndBrainAgent, public ndBrainThreadPool

 	ndBrainFloat m_gamma;
 	ndBrainFloat m_learnRate;
-	ndBrainFloat m_currentQValue;
 	ndBrainFloat m_explorationProbability;
 	ndBrainFloat m_minExplorationProbability;
 	ndBrainFloat m_explorationProbabilityAnnelining;
@@ -122,6 +120,8 @@
 	ndInt32 m_eposideCount;
 	ndInt32 m_bashBufferSize;
 	ndInt32 m_targetUpdatePeriod;
+	ndMovingAverage<1024> m_averageQvalue;
+	ndMovingAverage<64> m_averageFramesPerEpisodes;
 	bool m_collectingSamples;
 };

@@ -135,7 +135,6 @@ ndBrainAgentDQN_Trainer<statesDim, actionDim>::ndBrainAgentDQN_Trainer(const Hyp
 	,m_replayBuffer()
 	,m_gamma(hyperParameters.m_discountFactor)
 	,m_learnRate(hyperParameters.m_learnRate)
-	,m_currentQValue(ndBrainFloat(0.0f))
 	,m_explorationProbability(ndBrainFloat(1.0f))
 	,m_minExplorationProbability(hyperParameters.m_exploreMinProbability)
 	,m_explorationProbabilityAnnelining(hyperParameters.m_exploreAnnelining)
@@ -144,6 +143,8 @@ ndBrainAgentDQN_Trainer<statesDim, actionDim>::ndBrainAgentDQN_Trainer(const Hyp
 	,m_eposideCount(0)
 	,m_bashBufferSize(hyperParameters.m_bashBufferSize)
 	,m_targetUpdatePeriod(hyperParameters.m_targetUpdatePeriod)
+	,m_averageQvalue()
+	,m_averageFramesPerEpisodes()
 	,m_collectingSamples(true)
 {
 	// build neural net
@@ -219,7 +220,9 @@ ndInt32 ndBrainAgentDQN_Trainer<statesDim, actionDim>::GetFramesCount() const
 template<ndInt32 statesDim, ndInt32 actionDim>
 ndBrainFloat ndBrainAgentDQN_Trainer<statesDim, actionDim>::GetCurrentValue() const
 {
-	return m_currentQValue;
+	ndAssert(0);
+	//return m_currentQValue;
+	return 0;
 }

 template<ndInt32 statesDim, ndInt32 actionDim>
@@ -409,7 +412,10 @@ void ndBrainAgentDQN_Trainer<statesDim, actionDim>::AddExploration(ndBrainFloat*
 		action = qActionValues.ArgMax();
 	}

-	m_currentQValue = qActionValues[action];
+	if (!IsSampling())
+	{
+		m_averageQvalue.Update(qActionValues[action]);
+	}
 	m_currentTransition.m_action[0] = ndBrainFloat(action);
 }

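Note: the hunk above sits in the DQN exploration step: with probability m_explorationProbability a random action is taken, otherwise the argmax over the predicted Q-values (qActionValues.ArgMax()), and the greedy Q-value now feeds m_averageQvalue instead of the removed m_currentQValue scalar. A standalone epsilon-greedy sketch of that selection rule (generic C++, not the trainer's exact code; SelectAction is a hypothetical name):

#include <cstdlib>
#include <vector>

// Epsilon-greedy: explore uniformly with probability epsilon,
// otherwise exploit with the argmax of the current Q estimates.
int SelectAction(const std::vector<float>& qValues, float epsilon)
{
	const float sample = float(rand()) / float(RAND_MAX);
	if (sample < epsilon)
	{
		return rand() % int(qValues.size());
	}
	int best = 0;
	for (int i = 1; i < int(qValues.size()); ++i)
	{
		if (qValues[i] > qValues[best])
		{
			best = i;
		}
	}
	return best;
}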
@@ -454,6 +460,11 @@ void ndBrainAgentDQN_Trainer<statesDim, actionDim>::OptimizeStep()
 	{
 		ndExpandTraceMessage("collecting samples: frame %d out of %d, episode %d \n", m_frameCount, m_replayBuffer.GetCapacity(), m_eposideCount);
 	}
+
+	if (!IsSampling())
+	{
+		m_averageFramesPerEpisodes.Update(ndBrainFloat(m_framesAlive));
+	}
 	m_eposideCount++;
 	m_framesAlive = 0;
 }
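Note: the common thread in this file is replacing the instantaneous m_currentQValue with windowed statistics: an ndMovingAverage<1024> smooths the greedy Q-value picked during exploration and an ndMovingAverage<64> smooths episode lengths, both updated only once sampling is finished (!IsSampling()). A sketch of what a fixed-window template like ndMovingAverage<size> plausibly does; only Update() is visible in this diff, so the circular buffer and the GetAverage() accessor below are assumptions, not the dBrain class's actual code:

#include <cstring>

// Assumed fixed-size circular moving average; Update() pushes a sample,
// GetAverage() (hypothetical) averages over the samples seen so far.
template<int size>
class MovingAverage
{
	public:
	MovingAverage()
		:m_index(0)
		,m_count(0)
	{
		memset(m_window, 0, sizeof(m_window));
	}

	void Update(float value)
	{
		m_window[m_index] = value;
		m_index = (m_index + 1) % size;
		if (m_count < size)
		{
			m_count++;
		}
	}

	float GetAverage() const
	{
		float sum = 0.0f;
		for (int i = 0; i < m_count; ++i)
		{
			sum += m_window[i];
		}
		return (m_count > 0) ? sum / float(m_count) : 0.0f;
	}

	private:
	float m_window[size];
	int m_index;
	int m_count;
};

A long window (1024) gives a stable Q-value trend for the CSV plots, while the short window (64) tracks episode length more responsively.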
2 changes: 1 addition & 1 deletion newton-4.00/sdk/dBrain/ndBrainAgentDiscreteVPG_Trainer.h
@@ -394,7 +394,7 @@ void ndBrainAgentDiscreteVPG_Trainer<statesDim, actionDim>::CalcucateRewards()
 	{
 		m_rewards[i] = m_trajectory[i].m_reward + m_gamma * m_rewards[i + 1];
 	}
-	//m_currentQValue = m_rewards[0];
+
 	m_averageQvalue.Update(m_rewards[0]);
 	m_averageFramesPerEpisodes.Update(ndBrainFloat(steps));
 	m_rewards.GaussianNormalize();
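Note: the loop above is the standard backward recursion for discounted returns, m_rewards[i] = r_i + gamma * m_rewards[i + 1]; m_rewards[0] is the whole episode's return estimate, which now feeds the same moving-average bookkeeping before GaussianNormalize() rescales the targets for the policy-gradient update. A standalone sketch of that computation (generic C++, not the trainer's code):

#include <vector>

// Discounted returns computed backward over one episode:
// returns[i] = rewards[i] + gamma * returns[i + 1]
std::vector<float> DiscountedReturns(const std::vector<float>& rewards, float gamma)
{
	std::vector<float> returns(rewards.size());
	float running = 0.0f;
	for (int i = int(rewards.size()) - 1; i >= 0; --i)
	{
		running = rewards[i] + gamma * running;
		returns[i] = running;
	}
	return returns;
}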
