Skip to content

Commit

Permalink
Add functionality to enable estimators for CATS (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
olgavrou authored Mar 24, 2021
1 parent 4d05624 commit 993c080
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 5 deletions.
38 changes: 35 additions & 3 deletions basic-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import ips_snips
import mle
import ds_parse
import cats_utils


def compute_estimates(log_fp):
def compute_estimates(log_fp, cats_transformer=None):
# Init estimators
online = ips_snips.Estimator()
baseline1 = ips_snips.Estimator()
Expand Down Expand Up @@ -54,6 +55,36 @@ def compute_estimates(log_fp):

evts += 1

if x.startswith(b'{"_label_ca":') and x.strip().endswith(b'}'):
data = ds_parse.json_cooked_continuous_actions(x)
if cats_transformer is None:
raise RuntimeError("Not all of the required arguments for running with continuous actions have been provided.")
# passing logged action as predicted action to transformer
data = cats_transformer.transform(data, data['a'])
# passing baseline action as predicted action to transformer
data_baseline1 = cats_transformer.transform(data, cats_transformer.get_baseline1_prediction())

if data['skipLearn']:
continue

r = 0 if data['cost'] == b'0' else -float(data['cost'])

# Update estimators with tuple (p_log, r, p_pred)
online.add_example(data['p'], r, data['p'])
baseline1.add_example(data['p'], r, data_baseline1['pred_p'])
baselineR.add_example(data['p'], r, 1.0 / cats_transformer.continuous_range)

online_mle.add_example(data['p'], r, data['p'])
baseline1_mle.add_example(data['p'], r, data_baseline1['pred_p'])
baselineR_mle.add_example(data['p'], r, 1.0 / cats_transformer.continuous_range)

online_cressieread.add_example(data['p'], r, data['p'])
baseline1_cressieread.add_example(data['p'], r, data_baseline1['pred_p'])
baselineR_cressieread.add_example(data['p'], r, 1.0 / cats_transformer.continuous_range)

evts += 1


if log_fp.endswith('.gz'):
len_text = ds_parse.update_progress(i+1)
else:
Expand Down Expand Up @@ -88,7 +119,8 @@ def compute_estimates(log_fp):

parser = argparse.ArgumentParser()
parser.add_argument('-l','--log_fp', help="data file path (.json or .json.gz format - each line is a dsjson)", required=True)

parser = cats_utils.set_custom_args(parser)
args = parser.parse_args()
cats_transformer = cats_utils.get_cats_transformer(args)

compute_estimates(args.log_fp)
compute_estimates(args.log_fp, cats_transformer)
42 changes: 42 additions & 0 deletions cats_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import argparse
import math

def set_custom_args(parser):
parser.add_argument('--max_value', help="[CATS estimator] max value for continuous action range", required=False)
parser.add_argument('--min_value', help="[CATS estimator] min value for continuous action range", required=False)
parser.add_argument('--num_actions', help="[CATS estimator] number of actions used to discretize continuous range", required=False)
parser.add_argument('--bandwidth', help="[CATS estimator] bandwidth (radius) of randomization around discrete actions in terms of continuous range ", required=False)
return parser

def get_cats_transformer(args):
if args.num_actions and args.max_value and args.min_value and args.bandwidth:
return CatsTransformer(args.num_actions, args.bandwidth, args.max_value, args.min_value)
else:
return

class CatsTransformer:
def __init__(self, num_actions, bandwidth, max_value, min_value):
self.num_actions = int(num_actions)
self.max_value = float(max_value)
self.min_value = float(min_value)
self.bandwidth = float(bandwidth)

self.continuous_range = self.max_value - self.min_value
self.unit_range = self.continuous_range / float(self.num_actions)

def get_baseline1_prediction(self):
return self.min_value + (self.unit_range / 2.0)

def transform(self, data, pred_a):
logged_a = data['a']

ctr = min((self.num_actions - 1), math.floor((pred_a - self.min_value) / self.unit_range))
centre = self.min_value + ctr * self.unit_range + (self.unit_range / 2.0)

if(math.isclose(centre, logged_a, abs_tol=self.bandwidth)):
b = min(self.max_value, centre + self.bandwidth) - max(self.min_value, centre - self.bandwidth)
data['pred_p'] = 1.0 / b
else:
data['pred_p'] = 0.0

return data
21 changes: 21 additions & 0 deletions ds_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,25 @@ def json_cooked(x):
data['num_a'] = len(data['a_vec'])
data['skipLearn'] = b'"_skipLearn":true' in x[ind2+34:ind3] # len('"_label_Action":1,"_labelIndex":0,') = 34

return data

def json_cooked_continuous_actions(x):
#################################
# Optimized version based on expected structure:
# {"_label_ca":{"cost":0,"pdf_value":0.0181818,"action":185.5},"Timestamp":"2017-10-24T00:00:15.5160000Z","Version":"1","EventId":"fa68cd9a71764118a635fd3d7a908634","c":{}}"
# Assumption: "Version" value is 1 digit string
#
# Performance: 4x faster than Python JSON parser js = json.loads(x.strip())
#################################
ind1 = x.find(b',',22) # equal to: x.find(',"pdf_value',16)
ind2 = x.find(b',',ind1+13) # equal to: x.find(',"action',ind1+23)
ind3 = x.find(b'}',ind2+10)
ind4 = x.find(b',"T',ind3+34) # equal to: x.find(',"Timestamp',ind2+34)

data = {}
data['cost'] = float(x[21:ind1]) # len('{"_label_ca":"cost":') = 21
data['p'] = float(x[ind1+13:ind2]) # len(',"pdf_value":') = 13
data['a'] = float(x[ind2+10:ind3])
data['skipLearn'] = b'"_skipLearn":true' in x

return data
7 changes: 5 additions & 2 deletions ips_snips.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def get_estimate(self, type):
if type == 'ips':
return self.data['n']/self.data['N']
elif type == 'snips':
return self.data['n']/self.data['d']
if self.data['d'] != 0:
return self.data['n']/self.data['d']
else:
return 0
else:
raise('Error: Incorrect estimator type {}. Supported options are ips or snips'.format(type))

Expand All @@ -54,7 +57,7 @@ def get_interval(self, type, alpha=0.05):
bounds.append(beta.ppf(alpha / 2, successes, n - successes + 1))
bounds.append(beta.ppf(1 - alpha / 2, successes + 1, n - successes))
elif type == "gaussian":
if SoS > 0.0:
if SoS > 0.0 and den > 1:
zGaussianCdf = {
0.25: 1.15,
0.1: 1.645,
Expand Down
69 changes: 69 additions & 0 deletions test/test_pi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pseudo_inverse
import ips_snips
import cats_utils

def test_single_slot_pi_equivalent_to_ips():
"""PI should be equivalent to IPS when there is only a single slot"""
Expand All @@ -18,3 +19,71 @@ def test_single_slot_pi_equivalent_to_ips():
pi_estimator.add_example([p_log], r, [p_pred])
ips_estimator.add_example(p_log, r, p_pred)
assert is_close(pi_estimator.get_estimate('pi') , ips_estimator.get_estimate('ips'))


def test_cats_ips():
ips_estimator = ips_snips.Estimator()

prob_logs = [0.151704, 0.006250, 0.086, 0.086, 0.086]
action_logs = [15.0, 3.89, 22.3, 17.34, 31]
rewards = [0.1, 0.2, 0, 1.0, 1.0]

max_value = 32
bandwidth = 1
cats_transformer = cats_utils.CatsTransformer(num_actions=8, min_value=0, max_value=max_value, bandwidth=bandwidth)

for logged_action, r, logged_prob in zip(action_logs, rewards, prob_logs):
data = {}
data['a'] = logged_action
data['cost'] = r
data['p'] = logged_prob
if logged_action < (max_value / 2.0):
pred_action = logged_action + 2 * bandwidth
data = cats_transformer.transform(data, pred_action) # pred_action should be too far away, so pred_p should be 0
assert data['pred_p'] == 0.0
else:
pred_action = logged_action
data = cats_transformer.transform(data, logged_action) # same action, so pred_p should be 1
assert data['pred_p'] == 1.0 / (2 * bandwidth)

ips_estimator.add_example(data['p'], r, data['pred_p'])
assert ips_estimator.get_estimate('ips') >= ips_estimator.get_estimate('snips')

def test_cats_transformer_on_edges():
prob_logs = [0.151704, 0.006250, 0.086, 0.086]
action_logs = [0, 1, 31, 32]
rewards = [1.0, 1.0, 1.0, 1.0]

max_value = 32
bandwidth = 2
cats_transformer = cats_utils.CatsTransformer(num_actions=8, min_value=0, max_value=max_value, bandwidth=bandwidth)

for logged_action, r, logged_prob in zip(action_logs, rewards, prob_logs):
data = {}
data['a'] = logged_action
data['cost'] = r
data['p'] = logged_prob

pred_action = logged_action
data = cats_transformer.transform(data, logged_action) # same action, so pred_p should be 1
assert data['pred_p'] == 1.0 / (2 * bandwidth)


def test_cats_baseline():
max_value = 32
min_value = 0
bandwidth = 1
num_actions = 8
cats_transformer = cats_utils.CatsTransformer(num_actions=num_actions, min_value=min_value, max_value=max_value, bandwidth=bandwidth)
baseline = cats_transformer.get_baseline1_prediction()
## unit range is 4, min_value is 0 so baseline action should be the centre of the firt unit range, starting off from min_value i.e. 2
assert baseline == 2

max_value = 33
min_value = 1
bandwidth = 1
num_actions = 8
cats_transformer = cats_utils.CatsTransformer(num_actions=num_actions, min_value=min_value, max_value=max_value, bandwidth=bandwidth)
baseline = cats_transformer.get_baseline1_prediction()
## unit range is 4, min_value is 1 so baseline action should be the centre of the firt unit range, starting off from min_value i.e. 3
assert baseline == 3

0 comments on commit 993c080

Please sign in to comment.