-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTFR Record.py
91 lines (79 loc) · 3.14 KB
/
TFR Record.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import io
import logging
import os

import tensorflow.compat.v1 as tf
from PIL import Image
from waymo_open_dataset import dataset_pb2 as open_dataset

from utils import (
    bytes_feature,
    bytes_list_feature,
    float_list_feature,
    int64_feature,
    int64_list_feature,
    parse_frame,
)
def create_tf_example(filename, encoded_jpeg, annotations):
    """
    convert to tensorflow object detection API format
    args:
    - filename [str]: name of the image
    - encoded_jpeg [bytes-likes]: encoded image
    - annotations [list]: bboxes and classes; each item must expose
      `box.center_x`, `box.center_y`, `box.length`, `box.width` and `type`
    returns:
    - tf_example [tf.Example]
    raises:
    - KeyError: if an annotation's `type` is not one of the ids in `mapping`
    """
    # Open the image only to read its pixel dimensions, which are needed to
    # normalize the box coordinates below.
    with Image.open(io.BytesIO(encoded_jpeg)) as image:
        width, height = image.size

    # Class id -> human readable class name.
    mapping = {1: 'vehicle', 2: 'pedestrian', 4: 'cyclist'}
    image_format = b'jpg'

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    filename = filename.encode('utf8')
    for ann in annotations:
        box = ann.box
        # Boxes are stored as center + extent; convert to corner coordinates.
        # `length` is used as the extent along x and `width` as the extent
        # along y.
        half_length = 0.5 * box.length
        half_width = 0.5 * box.width
        xmin = box.center_x - half_length
        ymin = box.center_y - half_width
        xmax = box.center_x + half_length
        ymax = box.center_y + half_width

        # Store corners as fractions of the image size.
        xmins.append(xmin / width)
        xmaxs.append(xmax / width)
        ymins.append(ymin / height)
        ymaxs.append(ymax / height)
        classes.append(ann.type)
        classes_text.append(mapping[ann.type].encode('utf8'))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(filename),
        'image/source_id': bytes_feature(filename),
        'image/encoded': bytes_feature(encoded_jpeg),
        'image/format': bytes_feature(image_format),
        'image/object/bbox/xmin': float_list_feature(xmins),
        'image/object/bbox/xmax': float_list_feature(xmaxs),
        'image/object/bbox/ymin': float_list_feature(ymins),
        'image/object/bbox/ymax': float_list_feature(ymaxs),
        'image/object/class/text': bytes_list_feature(classes_text),
        'image/object/class/label': int64_list_feature(classes),
    }))
    return tf_example
def process_tfr(path, dest='processed'):
    """
    process a waymo tf record into a tf api tf record
    args:
    - path [str]: path to the waymo tf record file
    - dest [str]: directory that receives the converted record; it is
      created when missing. The output keeps the base name of `path`.
    """
    logger = logging.getLogger(__name__)

    # create processed data dir
    os.makedirs(dest, exist_ok=True)
    file_name = os.path.basename(path)

    logger.info('Processing %s', path)
    dataset = tf.data.TFRecordDataset(path, compression_type='')

    # The context manager closes the writer even if a frame fails to convert.
    with tf.python_io.TFRecordWriter(os.path.join(dest, file_name)) as writer:
        # NOTE(review): `data.numpy()` presumably requires eager execution
        # to be enabled - confirm for the TensorFlow version in use.
        for idx, data in enumerate(dataset):
            frame = open_dataset.Frame()
            frame.ParseFromString(bytearray(data.numpy()))
            encoded_jpeg, annotations = parse_frame(frame)
            # Give every frame its own name derived from the record name.
            filename = file_name.replace('.tfrecord', f'_{idx}.tfrecord')
            tf_example = create_tf_example(filename, encoded_jpeg, annotations)
            writer.write(tf_example.SerializeToString())
if __name__ == "__main__":
    # Configure logging so INFO-level progress messages are shown.
    logging.basicConfig(level=logging.INFO)

    # Command line: -p/--path is the Waymo tf record to convert.
    parser = argparse.ArgumentParser(
        description='Convert a Waymo Open dataset tf record to a tf api tf record')
    parser.add_argument('-p', '--path', required=True, type=str,
                        help='Waymo Open dataset tf record')
    args = parser.parse_args()
    process_tfr(args.path)