-
Notifications
You must be signed in to change notification settings - Fork 100
/
Copy pathtb-rename-events.py
executable file
·64 lines (59 loc) · 2.12 KB
/
tb-rename-events.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# this script renames summary tags ("event names") in TensorBoard log files.
# It rewrites each file in place (so make backups!)
#
# example:
#
# find . -name "*.tfevents*" -exec tb-rename-events.py {} "iteration-time" "iteration-time/iteration-time" \;
#
# more than one old tag can be remapped to one new tag - use `;` as a separator:
#
# tb-rename-events.py events.out.tfevents.1 "training loss;validation loss" "loss"
#
# this script is derived from https://stackoverflow.com/a/60080531/9201239
#
# Important: this script requires TensorFlow. It does not need a GPU: the GPU
# is deliberately hidden (CUDA_VISIBLE_DEVICES is cleared below).
import shlex
import sys
from pathlib import Path
import os
# avoid using the GPU
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.core.util.event_pb2 import Event
def rename_events(input_file, old_tags, new_tag):
    """Rename summary tags inside a single TensorBoard event file, in place.

    Every record of ``input_file`` is copied to a sibling temporary file
    (``<input_file>.new``). Any summary value whose tag is one of
    ``old_tags`` has its tag replaced by ``new_tag``; all other records and
    values are copied unchanged. When the copy completes, the temporary
    file replaces the original.

    Args:
        input_file: path (str or path-like) of the event file to rewrite.
        old_tags: iterable of tag names to rename. A single ``str`` is
            treated as one tag name, not as a collection of characters.
        new_tag: the tag name written in place of any matching old tag.

    Raises:
        Whatever TensorFlow raises while reading or writing records (e.g. on
        a truncated/corrupt file), or ``OSError`` from the final replace.
        On a failure during the copy, the temporary file is removed and the
        original file is left untouched.
    """
    input_file = os.fspath(input_file)
    if isinstance(old_tags, str):
        # Avoid `v.tag in "some;string"` silently doing substring matching.
        old_tags = [old_tags]
    # Set lookup: every summary value is checked against the old tags.
    old_tags = frozenset(old_tags)
    new_file = input_file + ".new"
    try:
        with tf.io.TFRecordWriter(new_file) as writer:
            for rec in tf.data.TFRecordDataset([input_file]):
                ev = Event()
                ev.MergeFromString(rec.numpy())
                # Check field presence explicitly; only events that carry a
                # summary have tagged values to rename.
                if ev.HasField('summary'):
                    for v in ev.summary.value:
                        if v.tag in old_tags:
                            v.tag = new_tag
                # Every record (summary or not) is copied so nothing is lost.
                writer.write(ev.SerializeToString())
    except BaseException:
        # Don't leave a half-written `.new` file behind; the original file
        # has not been modified yet.
        try:
            os.remove(new_file)
        except OSError:
            pass
        raise
    # os.replace (unlike os.rename) also overwrites an existing destination
    # on Windows.
    os.replace(new_file, input_file)
def rename_events_dir(input_file, old_tags, new_tag):
    """Rename summary tags in one event file, or in every event file under a directory.

    If ``input_file`` is a directory, every regular file below it
    (recursively) whose name contains ``tfevents`` is rewritten with
    ``rename_events``. Leftover ``*.new`` temporaries from an interrupted
    run are skipped. Any other path is treated as a single event file,
    exactly as before.

    Args:
        input_file: path (str or path-like) of an event file or a directory.
        old_tags: iterable of tag names to rename (a ``str`` is one tag).
        new_tag: the tag name written in place of any matching old tag.
    """
    root = Path(input_file)
    if not root.is_dir():
        # Single-file case: unchanged behavior.
        rename_events(input_file, old_tags, new_tag)
        return
    if not isinstance(old_tags, str):
        # Materialize once so an iterator isn't exhausted by the first file.
        old_tags = list(old_tags)
    # Collect (and sort for deterministic order) before processing:
    # rename_events creates and removes `.new` siblings while it runs.
    event_files = sorted(
        p for p in root.rglob('*tfevents*')
        if p.is_file() and not p.name.endswith('.new')
    )
    for event_file in event_files:
        rename_events(str(event_file), old_tags, new_tag)
if __name__ == '__main__':
    # Usage: <input file> <old tags> <new tag>; old tags are `;`-separated.
    if len(sys.argv) != 4:
        print(f'{sys.argv[0]} <input file> <old tags> <new tag>',
              file=sys.stderr)
        sys.exit(1)
    input_file, old_tags, new_tag = sys.argv[1:]
    print(input_file, shlex.quote(old_tags), shlex.quote(new_tag))
    # Drop empty entries (e.g. from a trailing or doubled `;`): an empty old
    # tag would otherwise rename every summary value whose tag is empty.
    old_tags = [tag for tag in old_tags.split(';') if tag]
    if not old_tags:
        print(f'{sys.argv[0]}: no old tags given', file=sys.stderr)
        sys.exit(1)
    rename_events_dir(input_file, old_tags, new_tag)