Skip to content

Commit

Permalink
Delete agents by agent id (#324)
Browse files Browse the repository at this point in the history
* Init

* Delete agents

* Set default deleted agents idx to -1

* Reset deleted agents in setmaps
  • Loading branch information
aaravpandya authored Jan 20, 2025
1 parent 114599a commit a713a83
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 2 deletions.
15 changes: 14 additions & 1 deletion src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,20 @@ namespace gpudrive
.def("expert_trajectory_tensor", &Manager::expertTrajectoryTensor)
.def("set_maps", &Manager::setMaps)
.def("world_means_tensor", &Manager::worldMeansTensor)
.def("metadata_tensor", &Manager::metadataTensor);
.def("metadata_tensor", &Manager::metadataTensor)
.def("deleteAgents", [](Manager &self, nb::dict py_agents_to_delete) {
std::unordered_map<int32_t, std::vector<int32_t>> agents_to_delete;

// Convert Python dict to C++ unordered_map
for (auto item : py_agents_to_delete) {
int32_t key = nb::cast<int32_t>(item.first);
std::vector<int32_t> value = nb::cast<std::vector<int32_t>>(item.second);
agents_to_delete[key] = value;
}

self.deleteAgents(agents_to_delete);
})
.def("deleted_agents_tensor", &Manager::deletedAgentsTensor);
}

}
9 changes: 9 additions & 0 deletions src/level_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,15 @@ static inline bool shouldAgentBeCreated(Engine &ctx, const MapObject &agentInit)
return false;
}

auto& deletedAgents = ctx.singleton<DeletedAgents>().deletedAgents;
for (CountT i = 0; i < consts::kMaxAgentCount; i++)
{
if(deletedAgents[i] == agentInit.id)
{
return false;
}
}

return true;
}

Expand Down
74 changes: 73 additions & 1 deletion src/mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,11 @@ void Manager::setMaps(const std::vector<std::string> &maps)

auto resetMapPtr = (ResetMap *)gpu_exec.getExported((uint32_t)ExportID::ResetMap) + world_idx;
REQ_CUDA(cudaMemcpy(resetMapPtr, &resetmap, sizeof(ResetMap), cudaMemcpyHostToDevice));

// reset agents to delete
auto agentsToDeleteDevicePtr = (int32_t *)gpu_exec.getExported((uint32_t)ExportID::DeletedAgents);
int32_t *agentsToDeletePtr = agentsToDeleteDevicePtr + world_idx * consts::kMaxAgentCount;
REQ_CUDA(cudaMemset(agentsToDeletePtr, -1, consts::kMaxAgentCount * sizeof(int32_t)));
}

#else
Expand All @@ -634,15 +639,82 @@ void Manager::setMaps(const std::vector<std::string> &maps)

auto resetMapPtr = (ResetMap *)cpu_exec.getExported((uint32_t)ExportID::ResetMap) + world_idx;
memcpy(resetMapPtr, &resetmap, sizeof(ResetMap));

// reset agents to delete
auto agentsToDeleteDevicePtr = (int32_t *)cpu_exec.getExported((uint32_t)ExportID::DeletedAgents);
int32_t *agentsToDeletePtr = agentsToDeleteDevicePtr + world_idx * consts::kMaxAgentCount;
memset(agentsToDeletePtr, -1, consts::kMaxAgentCount * sizeof(int32_t));
}
}

// Vector of range on integers from 0 to the number of worlds
std::vector<int32_t> worldIndices(maps.size());
std::vector<int32_t> worldIndices(impl_->cfg.scenes.size());
std::iota(worldIndices.begin(), worldIndices.end(), 0);
reset(worldIndices);
}

Tensor Manager::deletedAgentsTensor() const
{
return impl_->exportTensor(ExportID::DeletedAgents, TensorElementType::Int32,
{
impl_->numWorlds,
consts::kMaxAgentCount,
});
}

void Manager::deleteAgents(const std::unordered_map<int32_t, std::vector<int32_t>> &agentsToDelete)
{

ResetMap resetmap{
1,
};

if (impl_->cfg.execMode == madrona::ExecMode::CUDA)
{
#ifdef MADRONA_CUDA_SUPPORT
auto &gpu_exec = static_cast<CUDAImpl *>(impl_.get())->gpuExec;
auto agentsToDeleteDevicePtr = (int32_t *)gpu_exec.getExported((uint32_t)ExportID::DeletedAgents);
for (const auto &[worldIdx, agents] : agentsToDelete)
{
assert(worldIdx < impl_->cfg.scenes.size());
assert(agents.size() <= consts::kMaxAgentCount);
int32_t *agentsToDeletePtr = agentsToDeleteDevicePtr + worldIdx * consts::kMaxAgentCount;
for (size_t i = 0; i < agents.size(); i++)
{
REQ_CUDA(cudaMemcpy(agentsToDeletePtr + i, &agents[i], sizeof(int32_t), cudaMemcpyHostToDevice));
}
auto resetMapPtr = (ResetMap *)gpu_exec.getExported((uint32_t)ExportID::ResetMap) + worldIdx;
REQ_CUDA(cudaMemcpy(resetMapPtr, &resetmap, sizeof(ResetMap), cudaMemcpyHostToDevice));
}
#else
// Handle the case where CUDA support is not available
FATAL("Madrona was not compiled with CUDA support");
#endif
}
else
{
auto &cpu_exec = static_cast<CPUImpl *>(impl_.get())->cpuExec;
auto agentsToDeleteDevicePtr = (int32_t *)cpu_exec.getExported((uint32_t)ExportID::DeletedAgents);
for (const auto &[worldIdx, agents] : agentsToDelete)
{
assert(worldIdx < impl_->cfg.scenes.size());
assert(agents.size() <= consts::kMaxAgentCount);
int32_t *agentsToDeletePtr = agentsToDeleteDevicePtr + worldIdx * consts::kMaxAgentCount;
for (size_t i = 0; i < agents.size(); i++)
{
memcpy(agentsToDeletePtr + i, &agents[i], sizeof(int32_t));
}
auto resetMapPtr = (ResetMap *)cpu_exec.getExported((uint32_t)ExportID::ResetMap) + worldIdx;
memcpy(resetMapPtr, &resetmap, sizeof(ResetMap));
}
}

std::vector<int32_t> worldIndices(impl_->cfg.scenes.size());
std::iota(worldIndices.begin(), worldIndices.end(), 0);
reset(worldIndices);
}


Tensor Manager::actionTensor() const
{
return impl_->exportTensor(ExportID::Action, TensorElementType::Float32,
Expand Down
4 changes: 4 additions & 0 deletions src/mgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Manager {
MGR_EXPORT madrona::py::Tensor expertTrajectoryTensor() const;
MGR_EXPORT madrona::py::Tensor worldMeansTensor() const;
MGR_EXPORT madrona::py::Tensor metadataTensor() const;
MGR_EXPORT madrona::py::Tensor deletedAgentsTensor() const;
madrona::py::Tensor rgbTensor() const;
madrona::py::Tensor depthTensor() const;
// These functions are used by the viewer to control the simulation
Expand All @@ -78,6 +79,9 @@ class Manager {
float acceleration, float steering,
float headAngle);
MGR_EXPORT void setMaps(const std::vector<std::string> &maps);

MGR_EXPORT void deleteAgents(const std::unordered_map<int32_t, std::vector<int32_t>> &agentsToDelete);

// TODO: remove parameters
MGR_EXPORT std::vector<Shape>
getShapeTensorFromDeviceMemory();
Expand Down
7 changes: 7 additions & 0 deletions src/sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void Sim::registerTypes(ECSRegistry &registry, const Config &cfg)
registry.registerSingleton<Map>();
registry.registerSingleton<ResetMap>();
registry.registerSingleton<WorldMeans>();
registry.registerSingleton<DeletedAgents>();

registry.registerArchetype<Agent>();
registry.registerArchetype<PhysicsEntity>();
Expand All @@ -74,6 +75,7 @@ void Sim::registerTypes(ECSRegistry &registry, const Config &cfg)
registry.exportSingleton<Map>((uint32_t)ExportID::Map);
registry.exportSingleton<ResetMap>((uint32_t)ExportID::ResetMap);
registry.exportSingleton<WorldMeans>((uint32_t)ExportID::WorldMeans);
registry.exportSingleton<DeletedAgents>((uint32_t)ExportID::DeletedAgents);

registry.exportColumn<AgentInterface, Action>(
(uint32_t)ExportID::Action);
Expand Down Expand Up @@ -878,6 +880,11 @@ Sim::Sim(Engine &ctx,

auto& map = ctx.singleton<Map>();
map = *(init.map);

auto& deletedAgents = ctx.singleton<DeletedAgents>();
for (auto i = 0; i < consts::kMaxAgentCount; i++) {
deletedAgents.deletedAgents[i] = -1;
}
// Creates agents, walls, etc.
createPersistentEntities(ctx);

Expand Down
1 change: 1 addition & 0 deletions src/sim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ enum class ExportID : uint32_t {
ResetMap,
WorldMeans,
MetaData,
DeletedAgents,
NumExports
};

Expand Down
4 changes: 4 additions & 0 deletions src/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ namespace gpudrive
int32_t reset;
};

struct DeletedAgents {
int32_t deletedAgents[consts::kMaxAgentCount];
};

struct WorldMeans {
madrona::math::Vector3 mean; // TODO: Z is 0 for now, but can be used for 3D in future
};
Expand Down

0 comments on commit a713a83

Please sign in to comment.