Skip to content

Commit

Permalink
rename some classes for more consistency
Browse files: browse the repository at this point in the history
  • Loading branch information
JulioJerez committed Oct 20, 2023
1 parent 81b5880 commit 9f42ac3
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 108 deletions.
52 changes: 26 additions & 26 deletions newton-4.00/applications/ndSandbox/demos/ndCartpoleContinue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,25 @@ namespace ndCarpole_1
m_stateSize
};

class ndCartpole: public ndModelArticulation
class ndRobot: public ndModelArticulation
{
public:

#ifdef D_USE_VANILLA_POLICY_GRAD
class ndCartpoleAgent : public ndBrainAgentContinueVPG<m_stateSize, m_actionsSize>
class ndController : public ndBrainAgentContinueVPG<m_stateSize, m_actionsSize>
#else
class ndCartpoleAgent : public ndBrainAgentDDPG<m_stateSize, m_actionsSize>
class ndController : public ndBrainAgentDDPG<m_stateSize, m_actionsSize>
#endif
{
public:
#ifdef D_USE_VANILLA_POLICY_GRAD
ndCartpoleAgent(ndSharedPtr<ndBrain>& actor)
ndController(ndSharedPtr<ndBrain>& actor)
:ndBrainAgentContinueVPG<m_stateSize, m_actionsSize>(actor)
,m_model(nullptr)
{
}
#else
ndCartpoleAgent(ndSharedPtr<ndBrain>& actor)
ndController(ndSharedPtr<ndBrain>& actor)
:ndBrainAgentDDPG<m_stateSize, m_actionsSize>(actor)
,m_model(nullptr)
{
Expand All @@ -84,19 +84,19 @@ namespace ndCarpole_1
m_model->ApplyActions(actions);
}

ndCartpole* m_model;
ndRobot* m_model;
};


#ifdef D_USE_VANILLA_POLICY_GRAD
class ndCartpoleAgentTrainer : public ndBrainAgentContinueVPG_Trainer<m_stateSize, m_actionsSize>
class ndControllerTrainer : public ndBrainAgentContinueVPG_Trainer<m_stateSize, m_actionsSize>
#else
class ndCartpoleAgentTrainer : public ndBrainAgentDDPG_Trainer<m_stateSize, m_actionsSize>
class ndControllerTrainer : public ndBrainAgentDDPG_Trainer<m_stateSize, m_actionsSize>
#endif
{
public:
#ifdef D_USE_VANILLA_POLICY_GRAD
ndCartpoleAgentTrainer(const HyperParameters& hyperParameters)
ndControllerTrainer(const HyperParameters& hyperParameters)
:ndBrainAgentContinueVPG_Trainer<m_stateSize, m_actionsSize>(hyperParameters)
,m_bestActor(m_actor)
,m_model(nullptr)
Expand All @@ -110,7 +110,7 @@ namespace ndCarpole_1
fprintf(m_outFile, "vpg\n");
}
#else
ndCartpoleAgentTrainer(const HyperParameters& hyperParameters)
ndControllerTrainer(const HyperParameters& hyperParameters)
:ndBrainAgentDDPG_Trainer<m_stateSize, m_actionsSize>(hyperParameters)
,m_bestActor(m_actor)
,m_model(nullptr)
Expand All @@ -125,7 +125,7 @@ namespace ndCarpole_1
}
#endif

~ndCartpoleAgentTrainer()
~ndControllerTrainer()
{
if (m_outFile)
{
Expand Down Expand Up @@ -237,15 +237,15 @@ namespace ndCarpole_1

FILE* m_outFile;
ndBrain m_bestActor;
ndCartpole* m_model;
ndRobot* m_model;
ndUnsigned64 m_timer;
ndFloat32 m_maxGain;
ndInt32 m_maxFrames;
ndInt32 m_stopTraining;
bool m_modelIsTrained;
};

ndCartpole(const ndSharedPtr<ndBrainAgent>& agent)
ndRobot(const ndSharedPtr<ndBrainAgent>& agent)
:ndModelArticulation()
,m_cartMatrix(ndGetIdentityMatrix())
,m_poleMatrix(ndGetIdentityMatrix())
Expand Down Expand Up @@ -346,14 +346,14 @@ namespace ndCarpole_1
#ifdef D_TRAIN_AGENT
if (m_agent->IsTrainer())
{
ndCartpoleAgentTrainer* const agent = (ndCartpoleAgentTrainer*)(*m_agent);
ndControllerTrainer* const agent = (ndControllerTrainer*)(*m_agent);
if (agent->m_modelIsTrained)
{
char fileName[1024];
ndGetWorkingFileName(agent->GetName().GetStr(), fileName);
ndSharedPtr<ndBrain> actor(ndBrainLoad::Load(fileName));
m_agent = ndSharedPtr<ndBrainAgent>(new ndCartpole::ndCartpoleAgent(actor));
((ndCartpole::ndCartpoleAgent*)*m_agent)->m_model = this;
m_agent = ndSharedPtr<ndBrainAgent>(new ndRobot::ndController(actor));
((ndRobot::ndController*)*m_agent)->m_model = this;
//ResetModel();
((ndPhysicsWorld*)m_world)->NormalUpdates();
}
Expand All @@ -380,7 +380,7 @@ namespace ndCarpole_1
ndSharedPtr<ndBrainAgent> m_agent;
};

void BuildModel(ndCartpole* const model, ndDemoEntityManager* const scene, const ndMatrix& location)
void BuildModel(ndRobot* const model, ndDemoEntityManager* const scene, const ndMatrix& location)
{
ndFloat32 xSize = 0.25f;
ndFloat32 ySize = 0.125f;
Expand Down Expand Up @@ -429,7 +429,7 @@ namespace ndCarpole_1
}

#ifdef D_TRAIN_AGENT
ndCartpole* CreateTrainModel(ndDemoEntityManager* const scene, const ndMatrix& location)
ndRobot* CreateTrainModel(ndDemoEntityManager* const scene, const ndMatrix& location)
{
// add a reinforcement learning controller
#ifdef D_USE_VANILLA_POLICY_GRAD
Expand All @@ -442,10 +442,10 @@ namespace ndCarpole_1

//hyperParameters.m_threadsCount = 1;
hyperParameters.m_discountFactor = ndBrainFloat(0.995f);
ndSharedPtr<ndBrainAgent> agent(new ndCartpole::ndCartpoleAgentTrainer(hyperParameters));
ndSharedPtr<ndBrainAgent> agent(new ndRobot::ndControllerTrainer(hyperParameters));

ndCartpole* const model = new ndCartpole(agent);
ndCartpole::ndCartpoleAgentTrainer* const trainer = (ndCartpole::ndCartpoleAgentTrainer*)*agent;
ndRobot* const model = new ndRobot(agent);
ndRobot::ndControllerTrainer* const trainer = (ndRobot::ndControllerTrainer*)*agent;
trainer->m_model = model;
trainer->SetName(CONTROLLER_NAME);

Expand All @@ -459,16 +459,16 @@ namespace ndCarpole_1
ndModelArticulation* CreateModel(ndDemoEntityManager* const scene, const ndMatrix& location)
{
#ifdef D_TRAIN_AGENT
ndCartpole* const model = CreateTrainModel(scene, location);
ndRobot* const model = CreateTrainModel(scene, location);
#else
char fileName[1024];
ndGetWorkingFileName(CONTROLLER_NAME, fileName);

ndSharedPtr<ndBrain> actor(ndBrainLoad::Load(fileName));
ndSharedPtr<ndBrainAgent> agent(new ndCartpole::ndCartpoleAgent(actor));
ndSharedPtr<ndBrainAgent> agent(new ndRobot::ndController(actor));

ndCartpole* const model = new ndCartpole(agent);
((ndCartpole::ndCartpoleAgent*)*agent)->m_model = model;
ndRobot* const model = new ndRobot(agent);
((ndRobot::ndController*)*agent)->m_model = model;

BuildModel(model, scene, location);
#endif
Expand All @@ -478,7 +478,7 @@ namespace ndCarpole_1

using namespace ndCarpole_1;

void ndCartpoleContinuePlayer(ndDemoEntityManager* const scene)
void ndCartpoleContinue(ndDemoEntityManager* const scene)
{
BuildFlatPlane(scene, true);

Expand Down
54 changes: 27 additions & 27 deletions newton-4.00/applications/ndSandbox/demos/ndCartpoleDiscrete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,25 @@ namespace ndCarpole_0
m_stateSize
};

class ndCartpole: public ndModelArticulation
class ndRobot: public ndModelArticulation
{
public:

#ifdef D_USE_VANILLA_POLICY_GRAD
class ndCartpoleAgent : public ndBrainAgentDiscreteVPG<m_stateSize, m_actionsSize>
class ndController : public ndBrainAgentDiscreteVPG<m_stateSize, m_actionsSize>
#else
class ndCartpoleAgent : public ndBrainAgentDQN<m_stateSize, m_actionsSize>
class ndController : public ndBrainAgentDQN<m_stateSize, m_actionsSize>
#endif
{
public:
#ifdef D_USE_VANILLA_POLICY_GRAD
ndCartpoleAgent(ndSharedPtr<ndBrain>& actor)
ndController(ndSharedPtr<ndBrain>& actor)
:ndBrainAgentDiscreteVPG<m_stateSize, m_actionsSize>(actor)
,m_model(nullptr)
{
}
#else
ndCartpoleAgent(ndSharedPtr<ndBrain>& actor)
ndController(ndSharedPtr<ndBrain>& actor)
:ndBrainAgentDQN<m_stateSize, m_actionsSize>(actor)
,m_model(nullptr)
{
Expand All @@ -86,23 +86,23 @@ namespace ndCarpole_0
m_model->ApplyActions(actions);
}

void SetModel(ndCartpole* const model)
void SetModel(ndRobot* const model)
{
m_model = model;
}

ndCartpole* m_model;
ndRobot* m_model;
};

#ifdef D_USE_VANILLA_POLICY_GRAD
class ndCartpoleAgentTrainer : public ndBrainAgentDiscreteVPG_Trainer<m_stateSize, m_actionsSize>
class ndControllerTrainer : public ndBrainAgentDiscreteVPG_Trainer<m_stateSize, m_actionsSize>
#else
class ndCartpoleAgentTrainer : public ndBrainAgentDQN_Trainer<m_stateSize, m_actionsSize>
class ndControllerTrainer : public ndBrainAgentDQN_Trainer<m_stateSize, m_actionsSize>
#endif
{
public:
#ifdef D_USE_VANILLA_POLICY_GRAD
ndCartpoleAgentTrainer(ndBrainAgentDiscreteVPG_Trainer<m_stateSize, m_actionsSize>::HyperParameters hyperParameters)
ndControllerTrainer(ndBrainAgentDiscreteVPG_Trainer<m_stateSize, m_actionsSize>::HyperParameters hyperParameters)
:ndBrainAgentDiscreteVPG_Trainer<m_stateSize, m_actionsSize>(hyperParameters)
,m_bestActor(m_actor)
,m_model(nullptr)
Expand All @@ -116,7 +116,7 @@ namespace ndCarpole_0
fprintf(m_outFile, "VPG\n");
}
#else
ndCartpoleAgentTrainer(ndBrainAgentDQN_Trainer<m_stateSize, m_actionsSize>::HyperParameters hyperParameters)
ndControllerTrainer(ndBrainAgentDQN_Trainer<m_stateSize, m_actionsSize>::HyperParameters hyperParameters)
:ndBrainAgentDQN_Trainer<m_stateSize, m_actionsSize>(hyperParameters)
,m_bestActor(m_actor)
,m_model(nullptr)
Expand All @@ -130,7 +130,7 @@ namespace ndCarpole_0
}
#endif

~ndCartpoleAgentTrainer()
~ndControllerTrainer()
{
if (m_outFile)
{
Expand Down Expand Up @@ -231,15 +231,15 @@ namespace ndCarpole_0

FILE* m_outFile;
ndBrain m_bestActor;
ndCartpole* m_model;
ndRobot* m_model;
ndUnsigned64 m_timer;
ndFloat32 m_maxGain;
ndInt32 m_maxFrames;
ndInt32 m_stopTraining;
bool m_modelIsTrained;
};

ndCartpole(const ndSharedPtr<ndBrainAgent>& agent)
ndRobot(const ndSharedPtr<ndBrainAgent>& agent)
:ndModelArticulation()
,m_cartMatrix(ndGetIdentityMatrix())
,m_poleMatrix(ndGetIdentityMatrix())
Expand Down Expand Up @@ -347,14 +347,14 @@ namespace ndCarpole_0
#ifdef D_TRAIN_AGENT
if (m_agent->IsTrainer())
{
ndCartpoleAgentTrainer* const agent = (ndCartpoleAgentTrainer*)(*m_agent);
ndControllerTrainer* const agent = (ndControllerTrainer*)(*m_agent);
if (agent->m_modelIsTrained)
{
char fileName[1024];
ndGetWorkingFileName(agent->GetName().GetStr(), fileName);
ndSharedPtr<ndBrain> actor(ndBrainLoad::Load(fileName));
m_agent = ndSharedPtr<ndBrainAgent>(new ndCartpole::ndCartpoleAgent(actor));
((ndCartpole::ndCartpoleAgent*)*m_agent)->SetModel(this);
m_agent = ndSharedPtr<ndBrainAgent>(new ndRobot::ndController(actor));
((ndRobot::ndController*)*m_agent)->SetModel(this);
//ResetModel();
((ndPhysicsWorld*)m_world)->NormalUpdates();
}
Expand Down Expand Up @@ -382,7 +382,7 @@ namespace ndCarpole_0
ndSharedPtr<ndBrainAgent> m_agent;
};

void BuildModel(ndCartpole* const model, ndDemoEntityManager* const scene, const ndMatrix& location)
void BuildModel(ndRobot* const model, ndDemoEntityManager* const scene, const ndMatrix& location)
{
ndFloat32 xSize = 0.25f;
ndFloat32 ySize = 0.125f;
Expand Down Expand Up @@ -431,7 +431,7 @@ namespace ndCarpole_0
}

#ifdef D_TRAIN_AGENT
ndCartpole* CreateTrainModel(ndDemoEntityManager* const scene, const ndMatrix& location)
ndRobot* CreateTrainModel(ndDemoEntityManager* const scene, const ndMatrix& location)
{
#ifdef D_USE_VANILLA_POLICY_GRAD
ndBrainAgentDiscreteVPG_Trainer<m_stateSize, m_actionsSize>::HyperParameters hyperParameters;
Expand All @@ -442,10 +442,10 @@ namespace ndCarpole_0
#endif
//hyperParameters.m_threadsCount = 1;

ndSharedPtr<ndBrainAgent> agent(new ndCartpole::ndCartpoleAgentTrainer(hyperParameters));
ndSharedPtr<ndBrainAgent> agent(new ndRobot::ndControllerTrainer(hyperParameters));

ndCartpole* const model = new ndCartpole(agent);
ndCartpole::ndCartpoleAgentTrainer* const trainer = (ndCartpole::ndCartpoleAgentTrainer*)*agent;
ndRobot* const model = new ndRobot(agent);
ndRobot::ndControllerTrainer* const trainer = (ndRobot::ndControllerTrainer*)*agent;
trainer->m_model = model;
trainer->SetName(CONTROLLER_NAME);

Expand All @@ -459,16 +459,16 @@ namespace ndCarpole_0
ndModelArticulation* CreateModel(ndDemoEntityManager* const scene, const ndMatrix& location)
{
#ifdef D_TRAIN_AGENT
ndCartpole* const model = CreateTrainModel(scene, location);
ndRobot* const model = CreateTrainModel(scene, location);
#else
char fileName[1024];
ndGetWorkingFileName(CONTROLLER_NAME, fileName);

ndSharedPtr<ndBrain> actor(ndBrainLoad::Load(fileName));
ndSharedPtr<ndBrainAgent> agent(new ndCartpole::ndCartpoleAgent(actor));
ndSharedPtr<ndBrainAgent> agent(new ndRobot::ndController(actor));

ndCartpole* const model = new ndCartpole(agent);
((ndCartpole::ndCartpoleAgent*)*agent)->m_model = model;
ndRobot* const model = new ndRobot(agent);
((ndRobot::ndController*)*agent)->m_model = model;

BuildModel(model, scene, location);
#endif
Expand All @@ -478,7 +478,7 @@ namespace ndCarpole_0

using namespace ndCarpole_0;

void ndCartpoleDiscretePlayer(ndDemoEntityManager* const scene)
void ndCartpoleDiscrete(ndDemoEntityManager* const scene)
{
BuildFlatPlane(scene, true);

Expand Down
Loading

0 comments on commit 9f42ac3

Please sign in to comment.