-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch rlos2021_cfe to master (#36)
* Initial file structure changes for Package (#16) * Interface Approach2 (#29) * Split ips and snips. Make two classes: Estimator and Interval * Split pseudo_inverse into 2 classes: Estimator and Interval * Fix test_pi, remove redundant file ips_snips, Remove type argument from get_estimate * Slates interface implementation * Cb interface initial commit * Rename file and change class names * Edit doc strings * Change count datatype to float in cb_base * Added gaussian, clopper_pearson files and removed type from cb interface * Add newline at the end of file * Changes for slates - Renamed file from slates_helper to slates_base - Added gaussian.py - Removed type from get_interval - Removed type from get_estimate - Change doc strings for the slates interface - Changed class names - Changed data type of count - Fixed data type of p_log and p_pred - Removed unused imports * Remove redundant imports and code * Change method name to get() * Rename file to base and change class name of ips, snips * Change doc strings and variable name: slates * Changes for test_pi * Cressieread Interval update * Changes folder name and class names (#31) * Minimal changes tobasic-usage (#32) * Improvements for setup.py and slates (#33) * imports fix (#34) * Adding Tests (#35) * Unit tests added * Test for multiple examples * Added test for checking narrowing intervals * Combine all unit test functions into one * Added comments * Added another example generator * Fixed Imports * Change variable names and fix typo * Added check for correct format of Confidence Interval * Separate bandit and slates tests * Move functions to utils * Added test for correctness(slates) * Comments added for test_bandits * Added tests for slates intervals * Move data generators from helper files to test_* files * Remove num_slots as a parameter in util functions * Combine run_estimator function * Combine SlatesHelper and BanditsHelper * Move assert statements from run_estimator() to test_*.py * Move assert 
statements from Helper() functions to test_*.py file * Improving code consistency * Defined static methods and renamed file to utils.py * Add function assert_is_close to utils * Variable name changed * Restructuring of code * CI improvements (#38) * Added support for Python version 3.9 * CI: Check test coverage * Added interface and module for ccb (#37) * Added ccb estimator (#39) * Added ccb estimator file * Removed type and added Interval() * Added unit test for ccb + code corrections in ccb.py * Test for correctness and narrowing Intervals added * Changed module name * Change variable name * Removed hard coding for specific alpha values in gaussain files (#44) * Add tests (#43) * use random.seed() to make test scenarios reproducible * Change function names * Rename variables * Rename variables listofestimators->estimators and listofintervals->intervals * Renamed variables for test_narrowing_intervals * Added test to check alpha value is not hardcoded for bandits * Renamed to test_different_alpha_CI * Rlos2021 minor cleanup (#45) * minor cleanups * py35 removal * more type hints * snake case * ValueError * snake case Co-authored-by: Alexey Taymanov <[email protected]> Co-authored-by: Alexey Taymanov <[email protected]>
- Loading branch information
1 parent
993c080
commit bb971f6
Showing
31 changed files
with
958 additions
and
279 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#Jupyter notebook checkpoints | ||
**/.ipynb_checkpoints/* | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
*.egg-info | ||
|
||
# Python build artifacts | ||
build/ | ||
dist/ | ||
|
||
# Ignored log files produced in examples/ | ||
examples/*.log | ||
|
||
# Editors | ||
.vscode/ | ||
.idea/ | ||
|
||
# Type checking | ||
.mypy_cache | ||
|
||
.coverage |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
""" Interface for implementation of contextual bandit estimators """ | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
class Estimator(ABC):
    """Interface for implementation of contextual bandit estimators.

    Implementations accumulate logged examples through ``add_example`` and
    report a point estimate through ``get``.
    """

    @abstractmethod
    def add_example(self, p_log: float, r: float, p_pred: float, count: float = 1.0) -> None:
        """Add one logged example to the estimator.

        Args:
            p_log: probability of the logging policy
            r: reward for choosing an action in the given context
            p_pred: predicted probability of making decision
            count: weight of the example; defaults to 1.0 so that the
                interface signature agrees with the concrete estimators
        """
        ...

    @abstractmethod
    def get(self) -> float:
        """Calculate the selected estimator.

        Returns:
            The estimator value
        """
        ...
|
||
class Interval(ABC):
    """Interface for implementation of contextual bandit confidence intervals.

    Implementations accumulate logged examples through ``add_example`` and
    report a confidence interval through ``get``.
    """

    @abstractmethod
    def add_example(self, p_log: float, r: float, p_pred: float, count: float = 1.0) -> None:
        """Add one logged example to the interval's aggregates.

        Args:
            p_log: probability of the logging policy
            r: reward for choosing an action in the given context
            p_pred: predicted probability of making decision
            count: weight of the example; defaults to 1.0 so that the
                interface signature agrees with the concrete intervals
        """
        ...

    @abstractmethod
    def get(self, alpha: float = 0.05) -> List[float]:
        """Calculate the confidence interval.

        Args:
            alpha: alpha value of the interval; defaults to 0.05, the same
                default the concrete intervals use

        Returns:
            The confidence interval as a list of floats
        """
        ...
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import math | ||
from scipy.stats import beta | ||
from estimators.bandits import base | ||
from typing import List | ||
|
||
class Interval(base.Interval):
    """Clopper-Pearson confidence interval computed from IPS aggregates."""

    def __init__(self):
        # Aggregated quantities:
        #   'n': IPS of numerator
        #   'N': total number of samples in bin from log (IPS = n/N)
        #   'c': largest r * p_pred / p_log seen so far; used to rescale
        #        'n' and 'N' into the counts fed to Clopper-Pearson
        self.data = {'n': 0., 'N': 0, 'c': 0.}

    def add_example(self, p_log: float, r: float, p_pred: float, count: float = 1.0) -> None:
        """Fold one logged example into the aggregates.

        Args:
            p_log: probability of the logging policy
            r: reward for choosing an action in the given context
            p_pred: predicted probability of making decision
            count: weight of the example
        """
        self.data['N'] += count
        if p_pred > 0:
            p_over_p = p_pred / p_log
            if r != 0:
                self.data['n'] += r * p_over_p * count
                self.data['c'] = max(self.data['c'], r * p_over_p)

    def get(self, alpha: float = 0.05) -> List[float]:
        """Calculate the Clopper-Pearson confidence interval.

        Args:
            alpha: alpha value; ``alpha / 2`` is placed in each tail

        Returns:
            ``[lower, upper]``, or ``[0, 0]`` when no example with a positive
            weighted reward has been added.
        """
        num = self.data['n']
        den = self.data['N']
        max_weighted_cost = self.data['c']

        if not max_weighted_cost > 0.0:
            return [0, 0]

        successes = num / max_weighted_cost
        n = den / max_weighted_cost
        failures = n - successes

        # The beta quantile is undefined (nan) when a shape parameter is not
        # strictly positive. At those boundaries the Clopper-Pearson interval
        # is defined to touch 0 (no successes) or 1 (no failures).
        # NOTE(review): if successes exceeds n by 1 or more the lower bound is
        # still nan, as before; this presumably cannot happen for weighted
        # rewards within [0, 1] -- confirm against callers.
        if successes > 0:
            lower = beta.ppf(alpha / 2, successes, failures + 1)
        else:
            lower = 0.0

        if failures > 0:
            upper = beta.ppf(1 - alpha / 2, successes + 1, failures)
        else:
            upper = 1.0

        return [lower, upper]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import math | ||
from estimators.bandits import base | ||
from scipy import stats | ||
from typing import List | ||
|
||
class Interval(base.Interval):
    """Gaussian confidence interval computed from IPS aggregates."""

    def __init__(self):
        # Aggregated quantities:
        #   'n': IPS of numerator
        #   'N': total number of samples in bin from log (IPS = n/N)
        #   'SoS': sum of squares of numerator's items
        #          (needed for Gaussian confidence intervals)
        self.data = {'n': 0., 'N': 0, 'SoS': 0}

    def add_example(self, p_log: float, r: float, p_pred: float, count: float = 1.0) -> None:
        """Fold one logged example into the aggregates.

        Args:
            p_log: probability of the logging policy
            r: reward for choosing an action in the given context
            p_pred: predicted probability of making decision
            count: weight of the example
        """
        self.data['N'] += count
        if p_pred > 0:
            p_over_p = p_pred / p_log
            if r != 0:
                self.data['n'] += r * p_over_p * count
                self.data['SoS'] += ((r * p_over_p) ** 2) * count

    def get(self, alpha: float = 0.05) -> List[float]:
        """Calculate the Gaussian confidence interval around n/N.

        Args:
            alpha: alpha value; the normal quantile at ``1 - alpha / 2`` is used

        Returns:
            ``[lower, upper]``, or ``[0, 0]`` when the sum of squares is not
            positive or the total count is not greater than 1.
        """
        num = self.data['n']
        den = self.data['N']
        sum_of_sq = self.data['SoS']

        if not (sum_of_sq > 0.0 and den > 1):
            return [0, 0]

        z_gaussian_cdf = stats.norm.ppf(1 - alpha / 2)

        variance = (sum_of_sq - num * num / den) / (den - 1)
        # The subtraction above can round to a tiny negative number when the
        # samples are (almost) identical; clamp so math.sqrt does not raise.
        variance = max(variance, 0.0)

        mean = num / den
        gauss_delta = z_gaussian_cdf * math.sqrt(variance / den)
        return [mean - gauss_delta, mean + gauss_delta]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from estimators.bandits import base | ||
|
||
class Estimator(base.Estimator):
    """IPS estimator: the ratio n/N of the aggregates kept in ``self.data``."""

    def __init__(self):
        # Running aggregates:
        #   'n': IPS of numerator
        #   'N': total number of samples in bin from log (IPS = n/N)
        self.data = dict(n=0., N=0)

    def add_example(self, p_log: float, r: float, p_pred: float, count: float = 1.0) -> None:
        """Fold one logged example into the aggregates.

        Args:
            p_log: probability of the logging policy
            r: reward for choosing an action in the given context
            p_pred: predicted probability of making decision
            count: weight of the example
        """
        totals = self.data
        totals['N'] += count

        # Examples the evaluated policy would not pick only add to 'N'.
        if not p_pred > 0:
            return

        importance = p_pred / p_log
        if r != 0:
            totals['n'] += r * importance * count

    def get(self) -> float:
        """Return the IPS estimate n/N.

        Raises:
            ValueError: if the accumulated count 'N' is zero.
        """
        total = self.data['N']
        if total == 0:
            raise ValueError('Error: No data point added')

        return self.data['n'] / total
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from estimators.bandits import base | ||
|
||
class Estimator(base.Estimator):
    """SNIPS estimator: the ratio n/d of the aggregates kept in ``self.data``."""

    def __init__(self):
        # Running aggregates:
        #   'n': IPS of numerator
        #   'N': total number of samples in bin from log (IPS = n/N)
        #   'd': IPS of denominator (SNIPS = n/d)
        self.data = dict(n=0., N=0, d=0.)

    def add_example(self, p_log: float, r: float, p_pred: float, count: float = 1.0) -> None:
        """Fold one logged example into the aggregates.

        Args:
            p_log: probability of the logging policy
            r: reward for choosing an action in the given context
            p_pred: predicted probability of making decision
            count: weight of the example
        """
        totals = self.data
        totals['N'] += count

        # Examples the evaluated policy would not pick only add to 'N'.
        if not p_pred > 0:
            return

        importance = p_pred / p_log
        totals['d'] += importance * count
        if r != 0:
            totals['n'] += r * importance * count

    def get(self) -> float:
        """Return the SNIPS estimate n/d, or 0 when 'd' is zero.

        Raises:
            ValueError: if the accumulated count 'N' is zero.
        """
        if self.data['N'] == 0:
            raise ValueError('Error: No data point added')

        normalizer = self.data['d']
        return self.data['n'] / normalizer if normalizer != 0 else 0
Oops, something went wrong.