-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcloud_segmentation_tconcord3d
executable file
·170 lines (130 loc) · 7.21 KB
/
cloud_segmentation_tconcord3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#!/usr/bin/env python
import os
import rospy
import rospkg
from sensor_msgs.msg import Image, PointCloud2
from ros_numpy import msgify, numpify
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured, unstructured_to_structured
import torch
from threading import RLock
from tconcord3d.builder import model_builder
from tconcord3d.config.config import load_config_data
from tconcord3d.utils.load_save_util import load_checkpoint
# Root directory of the ROS package; the model config and weights below are located relative to it.
# It is resolved through rospack instead of the location of this file (the alternative kept below).
# pkg_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '..', '..'))
pkg_path = rospkg.RosPack().get_path('traversability_estimation')
def msgify_cloud(cloud, frame, stamp, names):
    """Pack a 2-D array into a PointCloud2 message.

    :param cloud: 2-D array; its last axis is split into the fields given by `names`.
    :param frame: Value written to the message header `frame_id`.
    :param stamp: Value written to the message header `stamp`.
    :param names: Field names, one per column of `cloud`.
    :return: The PointCloud2 message.
    """
    assert cloud.ndim == 2
    structured = unstructured_to_structured(cloud, names=names)
    cloud_msg = msgify(PointCloud2, structured)
    header = cloud_msg.header
    header.stamp = stamp
    header.frame_id = frame
    return cloud_msg
# transformation between Cartesian coordinates and polar coordinates
def cart2polar(input_xyz):
    """Convert Cartesian points to cylindrical coordinates.

    :param input_xyz: Array indexed as [point, coordinate] with x, y, z in columns 0, 1, 2.
    :return: Array with columns (rho, phi, z), where rho is the distance from the z axis
             and phi the angle in the xy plane as given by arctan2(y, x).
    """
    x = input_xyz[:, 0]
    y = input_xyz[:, 1]
    z = input_xyz[:, 2]
    radius = np.sqrt(x ** 2 + y ** 2)
    angle = np.arctan2(y, x)
    return np.stack((radius, angle, z), axis=1)
def polar2cat(input_xyz_polar):
    """Convert cylindrical coordinates back to Cartesian ones.

    Unlike `cart2polar`, the coordinates are indexed along the FIRST axis here:
    `input_xyz_polar[0]` is rho, `[1]` is phi and `[2]` is z.

    :param input_xyz_polar: Array (or sequence) holding rho, phi, z at indices 0, 1, 2.
    :return: Array with x, y, z stacked along axis 0.
    """
    radius = input_xyz_polar[0]
    angle = input_xyz_polar[1]
    height = input_xyz_polar[2]
    return np.stack((radius * np.cos(angle), radius * np.sin(angle), height), axis=0)
class CloudSegmentor:
    """Point cloud segmentation node logic.

    Subscribes to a PointCloud2 topic, converts every received cloud to cylindrical
    voxel features, runs the model built by `tconcord3d` on them and publishes the
    input points extended with a per-point `cost` field.
    """

    def __init__(self, cloud_topic='cloud'):
        """Read the node parameters, load the model and connect the ROS interface.

        :param cloud_topic: Topic with the input PointCloud2 messages.
        :raises IOError: If the configured model weights file does not exist.
        """
        self.lidar_frame = None
        # Held while a cloud is being processed and published.
        self.lock = RLock()

        self.device = rospy.get_param('~device', 'cuda')
        self.model_weights = rospy.get_param('~weights', "student_kitti_traversablity_f0_0_time_ema.pt")
        self.model_path = os.path.join(pkg_path, "config/weights/", "t-concord3d/%s" % self.model_weights)
        # Raise explicitly: an assert would be stripped when running with `python -O`.
        if not os.path.exists(self.model_path):
            raise IOError('Model weights not found: %s' % self.model_path)
        model_config_path = os.path.join(pkg_path, 'src/tconcord3d/config/semantickitti/'
                                                   'semantickitti_S0_0_T11_33_ssl_s20_p80.yaml')
        self.model_config = load_config_data(model_config_path)['model_params']
        # Voxel grid size along the (rho, phi, z) axes, taken from the model configuration.
        self.grid_size = np.asarray(self.model_config['output_shape'])
        self.model = self.load_model()
        rospy.loginfo('Loaded cloud segmentation model: %s', self.model_weights)

        self.input_pc_fields = ['x', 'y', 'z']
        self.output_pc_fields = ['x', 'y', 'z', 'cost']
        self.debug = rospy.get_param('~debug', False)
        # Point clouds with a time stamp older than this (in seconds) are not processed.
        self.max_age = rospy.get_param('~max_age', 0.5)
        # When set, the voxel grid always spans [min_volume_space, max_volume_space],
        # given in (rho, phi, z) order; otherwise the extent of each cloud is used.
        self.fixed_volume_space = rospy.get_param('~fixed_volume_space', True)
        self.max_volume_space = rospy.get_param('~max_volume_space', [50, np.pi, 2])
        self.min_volume_space = rospy.get_param('~min_volume_space', [0, -np.pi, -4])

        self.segm_cloud_pub = rospy.Publisher(rospy.get_param('~cloud_out', '~points'), PointCloud2, queue_size=1)
        self.depth_pub = rospy.Publisher('~depth', Image, queue_size=1)
        # Subscribe last, so that the callback never runs on a partially initialized object.
        self.cloud_sub = rospy.Subscriber(cloud_topic, PointCloud2, self.segment_cloud_cb)
        rospy.loginfo('Point cloud segmentation node is ready.')

    def load_model(self):
        """Build the network from the model config and load the checkpoint into it.

        :return: The model, moved to `self.device` and switched to evaluation mode.
        """
        model = model_builder.build(self.model_config).to(self.device)
        model = load_checkpoint(self.model_path, model)
        model = model.eval()
        return model

    def _volume_bounds(self, xyz_pol):
        """Return the (min_bound, max_bound) of the voxel grid in (rho, phi, z) order.

        :param xyz_pol: Non-empty array of points in cylindrical coordinates, one per row.
        """
        if self.fixed_volume_space:
            # The extent of the cloud is irrelevant here, so it is not computed at all.
            return np.asarray(self.min_volume_space), np.asarray(self.max_volume_space)

        # The 100th and 0th percentiles are the maximum and minimum of the radius.
        max_bound_r = np.percentile(xyz_pol[:, 0], 100, axis=0)
        min_bound_r = np.percentile(xyz_pol[:, 0], 0, axis=0)
        max_bound = np.concatenate(([max_bound_r], np.max(xyz_pol[:, 1:], axis=0)))
        min_bound = np.concatenate(([min_bound_r], np.min(xyz_pol[:, 1:], axis=0)))
        return min_bound, max_bound

    def _voxelize(self, xyz):
        """Compute the per-point model features and voxel indices.

        :param xyz: Non-empty (N, 3) array of Cartesian points.
        :return: Tuple (features, grid_ind), or None if the voxel grid is degenerate.
                 Each row of `features` holds the offset of the point from the center of
                 its voxel (3 values), its cylindrical coordinates (3), its x and y (2)
                 and one zero-filled column. `grid_ind` holds the integer voxel index of
                 each point along (rho, phi, z).
        """
        xyz_pol = cart2polar(xyz)
        min_bound, max_bound = self._volume_bounds(xyz_pol)

        # get grid index
        crop_range = max_bound - min_bound
        intervals = crop_range / (self.grid_size - 1)
        if (intervals == 0).any():
            # Dividing by a zero interval would yield NaN / inf voxel indices.
            rospy.logwarn('Cloud segmentation: Zero voxel interval, skipping the point cloud.')
            return None
        grid_ind = np.floor((np.clip(xyz_pol, min_bound, max_bound) - min_bound) / intervals).astype(int)

        # center data on each voxel for PTnet
        voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
        cylindr_xyz = np.concatenate((xyz_pol - voxel_centers, xyz_pol, xyz[:, :2]), axis=1)
        # Zero-filled last feature column.
        # NOTE(review): presumably stands in for an intensity channel - confirm against the model.
        sig = np.zeros((len(cylindr_xyz), 1))
        return np.concatenate((cylindr_xyz, sig), axis=1), grid_ind

    def segment_cloud_cb(self, pc_msg):
        """Segment one point cloud and publish it with the per-point cost.

        Clouds older than `self.max_age`, empty clouds and clouds producing a
        degenerate voxel grid are skipped with a warning.

        :param pc_msg: Input PointCloud2 message with at least x, y, z fields.
        """
        assert isinstance(pc_msg, PointCloud2)
        self.lidar_frame = pc_msg.header.frame_id

        # Discard old messages.
        age = (rospy.Time.now() - pc_msg.header.stamp).to_sec()
        if age > self.max_age:
            rospy.logwarn('Cloud segmentation: Discarding points %.1f s > %.1f s old.', age, self.max_age)
            return

        with self.lock:
            cloud = numpify(pc_msg)
            xyz = np.array(structured_to_unstructured(cloud[self.input_pc_fields]))
            rospy.logdebug('Point cloud of shape %s is received', xyz.shape)
            xyz = xyz.reshape((-1, 3))
            if xyz.shape[0] == 0:
                # Nothing to segment; the reductions and the model below need at least one point.
                rospy.logwarn('Cloud segmentation: Received an empty point cloud, skipping.')
                return
            # NOTE(review): non-finite points are not filtered out - confirm the input cloud is dense.

            model_input = self._voxelize(xyz)
            if model_input is None:
                return
            cylindr_features, grid_ind = model_input

            cylindr_features = torch.as_tensor(cylindr_features, dtype=torch.float32).to(self.device)
            grid_ind_tensor = torch.as_tensor(grid_ind, dtype=torch.float32).to(self.device)

            with torch.no_grad():
                rospy.logdebug('Input shapes: %s, %s', cylindr_features.shape, xyz.shape)
                t0 = rospy.Time.now().to_sec()
                predict_labels_raw = self.model([cylindr_features], [grid_ind_tensor], batch_size=1)
                rospy.loginfo('Inference took: %.3f [sec]', rospy.Time.now().to_sec() - t0)
                rospy.logdebug('Output shape: %s', predict_labels_raw.shape)
                predict_probability = torch.softmax(predict_labels_raw, dim=1)

            # Drop the leading (batch) axis and look up the probabilities of the voxel of each point.
            predict_probability = predict_probability.cpu().numpy().squeeze(0)
            predict_prob = predict_probability[:, grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2]]
            predict_prob = np.array(predict_prob, dtype=np.float32)

            cost = predict_prob[1].reshape((-1, 1))  # traversability index
            xyz_cost = np.concatenate([xyz, cost], axis=1)

            # NOTE(review): the output is stamped with the current time, not with the stamp
            # of the input cloud - confirm this is what the consumers of the topic expect.
            segm_pc_msg = msgify_cloud(xyz_cost, frame=self.lidar_frame, stamp=rospy.Time.now(),
                                       names=self.output_pc_fields)
            self.segm_cloud_pub.publish(segm_pc_msg)
def main():
    """Initialize the ROS node, create the segmentor and process callbacks until shutdown."""
    rospy.init_node('cloud_segmentation_tconcord3d', log_level=rospy.DEBUG)
    # Private parameters can be read only after the node has been initialized.
    input_topic = rospy.get_param('~cloud_in', 'points')
    segmentor = CloudSegmentor(cloud_topic=input_topic)
    rospy.spin()
# Start the node only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()