forked from michaelnny/alpha_zero
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
96 lines (76 loc) · 2.33 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) 2023 Michael Hu.
# This code is part of the book "The Art of Reinforcement Learning: Fundamentals, Mathematics, and Implementation with Python.".
# This project is released under the MIT License.
# See the accompanying LICENSE file for details.
from typing import Mapping, Text, Any
import logging
import sys
import time
import timeit
from collections import deque
def get_time_stamp(file_name: bool = False) -> str:
    """Return the current local time formatted as a string.

    Args:
        file_name: when truthy, use the compact ``YYYYmmdd_HHMMSS`` layout
            (no spaces or colons); otherwise use ``YYYY-mm-dd HH:MM:SS``.

    Returns:
        The formatted local-time string.
    """
    now = time.localtime()
    pattern = '%Y%m%d_%H%M%S' if file_name else '%Y-%m-%d %H:%M:%S'
    return time.strftime(pattern, now)
# Flags that absl defines by default (logging, debugging, profiling and help
# switches). They are excluded from the dict returned by
# extract_args_from_flags_dict. A frozenset gives O(1) membership tests and is
# built once at import time instead of on every call.
_ABSL_DEFAULT_FLAG_KEYS = frozenset(
    (
        'logtostderr',
        'alsologtostderr',
        'log_dir',
        'v',
        'verbosity',
        'logger_levels',
        'stderrthreshold',
        'showprefixforinfo',
        'run_with_pdb',
        'pdb_post_mortem',
        'pdb',
        'run_with_profiling',
        'profile_file',
        'use_cprofile_for_profiling',
        'only_check_args',
        '?',
        'help',
        'helpshort',
        'helpfull',
        'helpxml',
    )
)


def extract_args_from_flags_dict(flags_dict: Mapping[Text, Any]) -> Mapping[Text, Any]:
    """Return a copy of ``flags_dict`` without absl's built-in default flags.

    Args:
        flags_dict: mapping from flag name to value.

    Returns:
        A new ``dict`` with every entry of ``flags_dict`` whose key is not one
        of absl's default flags. Insertion order of the input is preserved.
    """
    return {k: v for k, v in flags_dict.items() if k not in _ABSL_DEFAULT_FLAG_KEYS}
def create_logger(level='INFO'):
    """Configure the root logger to write formatted records to stderr.

    Args:
        level: the logging threshold. Accepts one of these (case-insensitive):
            - a standard level name: 'DEBUG', 'INFO', 'WARNING'/'WARN',
              'ERROR', 'CRITICAL'/'FATAL';
            - a numeric level such as ``logging.DEBUG``.
            Unrecognised names fall back to INFO.

    Returns:
        The root ``logging.Logger``.

    Calling this more than once only updates the level. The stderr handler
    is attached a single time, so records are not emitted once per call.
    """
    level_by_name = {
        'CRITICAL': logging.CRITICAL,
        'FATAL': logging.FATAL,
        'ERROR': logging.ERROR,
        'WARNING': logging.WARNING,
        'WARN': logging.WARN,
        'INFO': logging.INFO,
        'DEBUG': logging.DEBUG,
    }
    # bool is a subclass of int; treat it like any other non-level value.
    if isinstance(level, int) and not isinstance(level, bool):
        veb = level
    else:
        veb = level_by_name.get(str(level).upper(), logging.INFO)

    logger = logging.getLogger()
    logger.setLevel(veb)

    # Tag our handler so a repeated call can recognise it. Otherwise every
    # call would add another handler and duplicate each log line.
    handler_name = 'util.create_logger'
    if not any(h.get_name() == handler_name for h in logger.handlers):
        handler = logging.StreamHandler(stream=sys.stderr)
        handler.set_name(handler_name)
        handler.setFormatter(
            logging.Formatter(
                fmt='%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S',
            )
        )
        logger.addHandler(handler)
    return logger
class Timer:
    """Context manager that records elapsed wall-clock time of a ``with`` body.

    Each exit from the ``with`` block appends the elapsed time, as measured by
    ``timeit.default_timer``, to ``history``. ``history`` is a bounded deque
    holding at most ``max_history`` of the most recent measurements.
    """

    def __init__(self, max_history=100):
        self.history = deque(maxlen=max_history)

    def __enter__(self):
        # Returning self allows ``with Timer() as t:``.
        self.start = timeit.default_timer()
        return self

    def __exit__(self, *args):
        # Runs on normal exit and on exceptions. Returning None means
        # exceptions still propagate.
        elapsed = timeit.default_timer() - self.start
        self.history.append(elapsed)

    def mean_time(self):
        """Average of the recorded durations, or 0 if nothing was recorded."""
        count = len(self.history)
        return sum(self.history) / count if count else 0

    def last_time(self):
        """Most recently recorded duration, or 0 if nothing was recorded."""
        return self.history[-1] if self.history else 0