-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add type hints, refactored codes, add code comments in sel_bw.py
- Loading branch information
Showing
1 changed file
with
62 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# GWR Bandwidth selection class | ||
|
||
#x_glob parameter does not yet do anything; it is for semiparametric | ||
# x_glob parameter does not yet do anything; it is for semiparametric | ||
|
||
__author__ = "Taylor Oshan [email protected]" | ||
|
||
|
@@ -10,7 +10,7 @@ | |
from scipy.spatial.distance import pdist | ||
from scipy.optimize import minimize_scalar | ||
from spglm.family import Gaussian, Poisson, Binomial | ||
from .kernels import Kernel,local_cdist | ||
from .kernels import Kernel, local_cdist | ||
from .gwr import GWR | ||
from .search import golden_section, equal_interval, multi_bw | ||
from .diagnostics import get_AICc, get_AIC, get_BIC, get_CV | ||
|
@@ -141,7 +141,7 @@ class Sel_BW(object): | |
>>> pov = np.array(data.by_col('PctPov')).reshape((-1,1)) | ||
>>> african_amer = np.array(data.by_col('PctBlack')).reshape((-1,1)) | ||
>>> X = np.hstack([rural, pov, african_amer]) | ||
Golden section search AICc - adaptive bisquare | ||
>>> bw = Sel_BW(coords, y, X).search(criterion='AICc') | ||
|
@@ -175,36 +175,51 @@ class Sel_BW(object): | |
""" | ||
|
||
def __init__(self, coords, y, X_loc, X_glob=None, family=Gaussian(), | ||
offset=None, kernel='bisquare', fixed=False, multi=False, | ||
constant=True, spherical=False,n_jobs=-1): | ||
def __init__(self, | ||
coords: list[tuple], | ||
y: np.array, | ||
X_loc: np.array, | ||
X_glob: np.array = None, | ||
family=Gaussian(), | ||
offset: np.array = None, | ||
kernel: str = 'bisquare', | ||
fixed: bool = False, | ||
multi: bool = False, | ||
constant: bool = True, | ||
spherical: bool = False, | ||
n_jobs: int = -1) -> None: | ||
self.coords = np.array(coords) | ||
self.y = y | ||
self.X_loc = X_loc | ||
if X_glob is not None: | ||
self.X_glob = X_glob | ||
else: | ||
self.X_glob = [] | ||
self.X_glob = X_glob if X_glob is not None else [] | ||
self.family = family | ||
self.fixed = fixed | ||
self.kernel = kernel | ||
if offset is None: | ||
self.offset = np.ones((len(y), 1)) | ||
else: | ||
self.offset = offset * 1.0 | ||
self.offset = np.ones((len(y), 1)) if offset is None else offset * 1.0 | ||
self.multi = multi | ||
self._functions = [] | ||
self.constant = constant | ||
self.spherical = spherical | ||
self.n_jobs = n_jobs | ||
self.search_params = {} | ||
|
||
def search(self, search_method='golden_section', criterion='AICc', | ||
bw_min=None, bw_max=None, interval=0.0, tol=1.0e-6, | ||
max_iter=200, init_multi=None, tol_multi=1.0e-5, | ||
rss_score=False, max_iter_multi=200, multi_bw_min=[None], | ||
multi_bw_max=[None | ||
], bws_same_times=5, verbose=False,pool=None): | ||
def search(self, | ||
search_method: str = 'golden_section', | ||
criterion: str = 'AICc', | ||
bw_min: float = None, | ||
bw_max: float = None, | ||
interval: int = 0.0, | ||
tol: float = 1.0e-6, | ||
max_iter: int = 200, | ||
init_multi: float = None, | ||
tol_multi: float = 1.0e-5, | ||
rss_score: bool = False, | ||
max_iter_multi: int = 200, | ||
multi_bw_min: list = [None], | ||
multi_bw_max: list = [None], | ||
bws_same_times: int = 5, | ||
verbose: bool = False, | ||
pool: int = None): | ||
""" | ||
Method to select one unique bandwidth for a gwr model or a | ||
bandwidth vector for a mgwr model. | ||
|
@@ -219,7 +234,7 @@ def search(self, search_method='golden_section', criterion='AICc', | |
min value used in bandwidth search | ||
bw_max : float | ||
max value used in bandwidth search | ||
multi_bw_min : list | ||
multi_bw_min : list | ||
min values used for each covariate in mgwr bandwidth search. | ||
Must be either a single value or have one value for | ||
each covariate including the intercept | ||
|
@@ -263,7 +278,7 @@ def search(self, search_method='golden_section', criterion='AICc', | |
designs matrix, X | ||
""" | ||
k = self.X_loc.shape[1] | ||
if self.constant: #k is the number of covariates | ||
if self.constant: # k is the number of covariates | ||
k += 1 | ||
self.search_method = search_method | ||
self.criterion = criterion | ||
|
@@ -295,7 +310,6 @@ def search(self, search_method='golden_section', criterion='AICc', | |
if pool: | ||
warnings.warn("The pool parameter is no longer used and will have no effect; parallelization is default and implemented using joblib instead.", RuntimeWarning, stacklevel=2) | ||
|
||
|
||
self.interval = interval | ||
self.tol = tol | ||
self.max_iter = max_iter | ||
|
@@ -310,16 +324,15 @@ def search(self, search_method='golden_section', criterion='AICc', | |
self.search_params['interval'] = interval | ||
self.search_params['tol'] = tol | ||
self.search_params['max_iter'] = max_iter | ||
#self._check_min_max() | ||
# self._check_min_max() | ||
|
||
self.int_score = not self.fixed | ||
|
||
if self.multi: | ||
self._mbw() | ||
self.params = self.bw[3] #params n by k | ||
self.sel_hist = self.bw[-2] #bw searching history | ||
self.bw_init = self.bw[ | ||
-1] #scalar, optimal bw from initial gwr model | ||
self.params = self.bw[3] # params n by k | ||
self.sel_hist = self.bw[-2] # bw searching history | ||
self.bw_init = self.bw[-1] # scalar, optimal bw from initial gwr model | ||
else: | ||
self._bw() | ||
self.sel_hist = self.bw[-1] | ||
|
@@ -337,7 +350,7 @@ def _bw(self): | |
if self.search_method == 'golden_section': | ||
a, c = self._init_section(self.X_glob, self.X_loc, self.coords, | ||
self.constant) | ||
delta = 0.38197 #1 - (np.sqrt(5.0)-1.0)/2.0 | ||
delta = 0.38197 # 1 - (np.sqrt(5.0)-1.0)/2.0 | ||
self.bw = golden_section(a, c, delta, gwr_func, self.tol, | ||
self.max_iter, self.bw_max, self.int_score, | ||
self.verbose) | ||
|
@@ -359,22 +372,27 @@ def _bw(self): | |
self.search_method) | ||
|
||
def _mbw(self): | ||
|
||
# TODO: Do we need to assign these self variables to local variables here? | ||
# TODO: These local variables refer to the same ram locations as it is not a deepcopy. | ||
# TODO: Recommend to use self variables directly in gwr_func, bw_func, and sel_func. | ||
|
||
y = self.y | ||
if self.constant: | ||
X,keep_x,warn = USER.check_constant(self.X_loc) | ||
X, keep_x, warn = USER.check_constant(self.X_loc) | ||
else: | ||
X = self.X_loc | ||
n, k = X.shape | ||
family = self.family | ||
offset = self.offset | ||
kernel = self.kernel | ||
fixed = self.fixed | ||
spherical = self.spherical | ||
# spherical = self.spherical, TODO: Need to delete this line as it is not used | ||
coords = self.coords | ||
search_method = self.search_method | ||
criterion = self.criterion | ||
bw_min = self.bw_min | ||
bw_max = self.bw_max | ||
# bw_min = self.bw_min TODO: Need to delete this line as it is not used | ||
# bw_max = self.bw_max TODO: Need to delete this line as it is not used | ||
multi_bw_min = self.multi_bw_min | ||
multi_bw_max = self.multi_bw_max | ||
interval = self.interval | ||
|
@@ -385,13 +403,13 @@ def _mbw(self): | |
def gwr_func(y, X, bw): | ||
return GWR(coords, y, X, bw, family=family, kernel=kernel, | ||
fixed=fixed, offset=offset, constant=False, | ||
spherical=self.spherical, hat_matrix=False,n_jobs=self.n_jobs).fit( | ||
spherical=self.spherical, hat_matrix=False, n_jobs=self.n_jobs).fit( | ||
lite=True) | ||
|
||
def bw_func(y, X): | ||
selector = Sel_BW(coords, y, X, X_glob=[], family=family, | ||
kernel=kernel, fixed=fixed, offset=offset, | ||
constant=False, spherical=self.spherical,n_jobs=self.n_jobs) | ||
constant=False, spherical=self.spherical, n_jobs=self.n_jobs) | ||
return selector | ||
|
||
def sel_func(bw_func, bw_min=None, bw_max=None): | ||
|
@@ -405,32 +423,23 @@ def sel_func(bw_func, bw_min=None, bw_max=None): | |
bw_func, sel_func, multi_bw_min, multi_bw_max, | ||
bws_same_times, verbose=self.verbose) | ||
|
||
def _init_section(self, X_glob, X_loc, coords, constant): | ||
if len(X_glob) > 0: | ||
n_glob = X_glob.shape[1] | ||
else: | ||
n_glob = 0 | ||
if len(X_loc) > 0: | ||
n_loc = X_loc.shape[1] | ||
else: | ||
n_loc = 0 | ||
if constant: | ||
n_vars = n_glob + n_loc + 1 | ||
else: | ||
n_vars = n_glob + n_loc | ||
def _init_section(self, X_glob, X_loc, coords, constant) -> tuple: | ||
n_glob = X_glob.shape[1] if len(X_glob) > 0 else 0 | ||
n_loc = X_loc.shape[1] if len(X_loc) > 0 else 0 | ||
n_vars = n_glob + n_loc + 1 if constant else n_glob + n_loc | ||
n = np.array(coords).shape[0] | ||
|
||
if self.int_score: | ||
a = 40 + 2 * n_vars | ||
c = n | ||
else: | ||
min_dist = np.min(np.array([np.min(np.delete( | ||
local_cdist(coords[i],coords,spherical=self.spherical),i)) | ||
local_cdist(coords[i], coords, spherical=self.spherical), i)) | ||
for i in range(n)])) | ||
max_dist = np.max(np.array([np.max( | ||
local_cdist(coords[i],coords,spherical=self.spherical)) | ||
local_cdist(coords[i], coords, spherical=self.spherical)) | ||
for i in range(n)])) | ||
|
||
a = min_dist / 2.0 | ||
c = max_dist * 2.0 | ||
|
||
|
@@ -439,4 +448,5 @@ def _init_section(self, X_glob, X_loc, coords, constant): | |
if self.bw_max is not None and self.bw_max is not np.inf: | ||
c = self.bw_max | ||
|
||
return a, c | ||
# use tuple or list in the return if multiple outputs are needed | ||
return (a, c) |