Adding SpatialTemporalStatsTool to ush
Showing 2 changed files with 328 additions and 0 deletions.
SpatialTemporalStats.py (new file, 281 lines):
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import geopandas as gpd
import xarray
from shapely.geometry import Polygon, Point

class SpatialTemporalStats:
    def __init__(self):
        self.grid_gdf = None
        self.obs_gdf = None

    def generate_grid(self, resolution=1):
        """Build a global grid of square lat/lon cells at the given resolution (degrees)."""
        self.resolution = resolution

        # Generate the lower-left corner coordinates of every grid cell
        grid_lons, grid_lats = np.meshgrid(np.arange(-180, 181, resolution),
                                           np.arange(-90, 91, resolution))

        # Flatten the arrays to get one (lon, lat) pair per cell
        grid_coords = np.vstack([grid_lons.flatten(), grid_lats.flatten()]).T

        # Create a GeoDataFrame of square cells from the corner coordinates
        self.grid_gdf = gpd.GeoDataFrame(
            geometry=[Polygon([(lon, lat), (lon + resolution, lat),
                               (lon + resolution, lat + resolution),
                               (lon, lat + resolution)])
                      for lon, lat in grid_coords],
            crs='EPSG:4326')  # CRS for WGS84
        self.grid_gdf["grid_id"] = np.arange(1, len(self.grid_gdf) + 1)

    def _extract_date_times(self, filenames):
        """Parse the cycle date/time from each diag file name.

        The names follow the pattern 'diag_<sensor>_ges.<YYYYMMDDHH>.nc4', so the
        second-to-last dot-separated field carries the date/time stamp. Parsing the
        name avoids opening every nc4 file just to read its time attribute.
        """
        date_times = []
        for filename in filenames:
            # Split the filename by '.' and take the field holding the stamp
            date_time_part = filename.split('.')[-2]

            # The stamp is formatted 'YYYYMMDDHH'; parse it field by field
            year = int(date_time_part[:4])
            month = int(date_time_part[4:6])
            day = int(date_time_part[6:8])
            hour = int(date_time_part[8:10])

            date_times.append(datetime(year, month, day, hour))

        return date_times

    def read_obs_values(self, obs_files_path, sensor, var_name, channel_no,
                        start_date, end_date, filter_by_vars, QC_filter):
        """Read diag files for a sensor/channel over a date range and aggregate onto the grid."""
        self.sensor = sensor
        self.channel_no = channel_no

        # Collect the diag files for this sensor
        all_files = os.listdir(obs_files_path)
        obs_files = [os.path.join(obs_files_path, file) for file in all_files
                     if file.endswith('.nc4') and "diag_%s_ges" % sensor in file]

        # Get the date/time of each cycle from the file names; the file attributes
        # hold the same information, but reading them would mean opening every nc4.
        files_date_times_df = pd.DataFrame()
        files_date_times_df["file_name"] = obs_files
        files_date_times_df["date_time"] = self._extract_date_times(obs_files)
        files_date_times_df['date'] = pd.to_datetime(files_date_times_df['date_time'].dt.date)

        # Parse the start and end dates and select the cycles in range
        start_date = datetime.strptime(start_date, '%Y-%m-%d')
        end_date = datetime.strptime(end_date, '%Y-%m-%d')
        studied_cycle_files = files_date_times_df[
            (files_date_times_df["date"] >= start_date)
            & (files_date_times_df["date"] <= end_date)]["file_name"]

        studied_gdf_list = []
        for this_cycle_obs_file in studied_cycle_files:
            ds = xarray.open_dataset(this_cycle_obs_file)

            # Keep only the requested channel
            combined_bool = ds["Channel_Index"].data == channel_no

            # Optionally keep only observations that passed QC
            if QC_filter:
                combined_bool = combined_bool & (ds["QC_Flag"].data == 0)

            # Apply the per-variable filters: observations matching the stated
            # condition are EXCLUDED, hence the negation below
            for filter_var_name, filter_operation, filter_value in filter_by_vars:
                if filter_operation == 'lt':
                    this_filter_bool = ds[filter_var_name].data <= filter_value
                else:
                    this_filter_bool = ds[filter_var_name].data >= filter_value
                combined_bool = combined_bool & ~this_filter_bool

            this_cycle_var_values = ds[var_name].data[combined_bool]
            this_cycle_lat_values = ds["Latitude"].data[combined_bool]
            this_cycle_long_values = ds["Longitude"].data[combined_bool]
            # Wrap longitudes from [0, 360) into [-180, 180] to match the grid
            this_cycle_long_values = np.where(this_cycle_long_values <= 180,
                                              this_cycle_long_values,
                                              this_cycle_long_values - 360)

            # Build a GeoDataFrame of observation points for this cycle
            geometry = [Point(xy) for xy in zip(this_cycle_long_values, this_cycle_lat_values)]
            this_cycle_gdf = gpd.GeoDataFrame(geometry=geometry, crs='EPSG:4326')
            this_cycle_gdf["value"] = this_cycle_var_values

            studied_gdf_list.append(this_cycle_gdf)

        studied_gdf = pd.concat(studied_gdf_list)

        # Spatially join each observation point to the grid cell containing it
        # ('predicate' replaces the deprecated 'op' keyword in geopandas >= 0.10)
        joined_gdf = gpd.sjoin(studied_gdf, self.grid_gdf, predicate='within', how='right')

        # Aggregate point values per grid cell, mapping the grouped statistics back
        # through grid_id so they line up with the right cells (the grid's own
        # index is 0-based while grid_id starts at 1)
        grouped = joined_gdf.groupby('grid_id')['value']
        self.obs_gdf = self.grid_gdf.copy()
        self.obs_gdf[var_name + '_Average'] = self.obs_gdf['grid_id'].map(grouped.mean())
        self.obs_gdf[var_name + '_RMS'] = self.obs_gdf['grid_id'].map(
            grouped.apply(lambda x: np.sqrt((x ** 2).mean())))
        self.obs_gdf[var_name + '_Count'] = self.obs_gdf['grid_id'].map(grouped.count())

        return self.obs_gdf

    def filter_by_variable(self, var_name, comparison_type, value):
        # Placeholder: per-variable filtering is currently handled inside read_obs_values
        pass
    def plot_obs(self, selected_var_gdf, var_name, region, resolution, output_path):
        """Map the gridded average, count, and RMS for one of six latitude-band regions."""
        self.resolution = resolution
        var_names = [var_name + '_Average', var_name + '_Count', var_name + '_RMS']

        for item in var_names:
            plt.figure(figsize=(12, 8))
            ax = plt.subplot(1, 1, 1)

            if region == 1:
                # Global region (no filtering needed)
                title = 'Global Region'
                filtered_gdf = selected_var_gdf
            elif region == 2:
                # Northern polar region (+60 latitude and above)
                title = 'Polar Region (+60 latitude and above)'
                filtered_gdf = selected_var_gdf[selected_var_gdf.geometry.apply(
                    lambda geom: self.is_polygon_in_polar_region(geom, 60))]
            elif region == 3:
                # Northern mid-latitudes region (20 to 60 latitude)
                title = 'Northern Mid-latitudes Region (20 to 60 latitude)'
                filtered_gdf = selected_var_gdf[selected_var_gdf.geometry.apply(
                    lambda geom: self.is_polygon_in_latitude_range(geom, 20, 60))]
            elif region == 4:
                # Tropics region (-20 to 20 latitude)
                title = 'Tropics Region (-20 to 20 latitude)'
                filtered_gdf = selected_var_gdf[selected_var_gdf.geometry.apply(
                    lambda geom: self.is_polygon_in_latitude_range(geom, -20, 20))]
            elif region == 5:
                # Southern mid-latitudes region (-60 to -20 latitude)
                title = 'Southern Mid-latitudes Region (-60 to -20 latitude)'
                filtered_gdf = selected_var_gdf[selected_var_gdf.geometry.apply(
                    lambda geom: self.is_polygon_in_latitude_range(geom, -60, -20))]
            elif region == 6:
                # Southern polar region (less than -60 latitude)
                title = 'Southern Polar Region (less than -60 latitude)'
                filtered_gdf = selected_var_gdf[selected_var_gdf.geometry.apply(
                    lambda geom: geom.centroid.y < -60)]
            else:
                raise ValueError("region must be an integer from 1 to 6")

            min_val, max_val, std_val = (filtered_gdf[item].min(),
                                         filtered_gdf[item].max(),
                                         filtered_gdf[item].std())
            cbar_label = 'grid=%dX%d, min=%.3lf, max=%.3lf, std=%.3lf\n' % (
                resolution, resolution, min_val, max_val, std_val)

            filtered_gdf.plot(ax=ax, cmap='jet', column=item, legend=True,
                              missing_kwds={'color': 'lightgrey'},
                              legend_kwds={'orientation': 'horizontal', 'shrink': 0.5,
                                           'label': cbar_label})

            plt.title("%s\n%s ch:%d %s" % (title, self.sensor, self.channel_no, item))
            plt.savefig(os.path.join(output_path, "%s_ch%d_%s_region_%d.png"
                                     % (self.sensor, self.channel_no, item, region)))
            plt.close()
    def is_polygon_in_polar_region(self, polygon, latitude_threshold):
        """Check if a polygon is in the polar region based on a latitude threshold."""
        # A cell belongs to the region if its centroid latitude is at or above the threshold
        return polygon.centroid.y >= latitude_threshold

    def is_polygon_in_latitude_range(self, polygon, min_latitude, max_latitude):
        """Check if a polygon is in the specified latitude range."""
        # A cell belongs to the region if its centroid latitude falls inside the range
        return min_latitude <= polygon.centroid.y <= max_latitude
    def list_variable_names(self, file_path):
        """Print the dimensions and variables available in a diag file."""
        ds = xarray.open_dataset(file_path)
        ds.info()  # info() prints directly; wrapping it in print() would also emit 'None'
    def make_summary_plots(self, obs_files_path, sensor, var_name, start_date,
                           end_date, QC_filter, output_path):
        """Plot per-channel counts and mean/std of var_name over the date range."""
        self.sensor = sensor

        # Collect the diag files for this sensor
        all_files = os.listdir(obs_files_path)
        obs_files = [os.path.join(obs_files_path, file) for file in all_files
                     if file.endswith('.nc4') and "diag_%s_ges" % sensor in file]

        # Get the date/time of each cycle from the file names; the file attributes
        # hold the same information, but reading them would mean opening every nc4.
        files_date_times_df = pd.DataFrame()
        files_date_times_df["file_name"] = obs_files
        files_date_times_df["date_time"] = self._extract_date_times(obs_files)
        files_date_times_df['date'] = pd.to_datetime(files_date_times_df['date_time'].dt.date)

        # Parse the start and end dates and select the cycles in range
        start_date = datetime.strptime(start_date, '%Y-%m-%d')
        end_date = datetime.strptime(end_date, '%Y-%m-%d')
        studied_cycle_files = files_date_times_df[
            (files_date_times_df["date"] >= start_date)
            & (files_date_times_df["date"] <= end_date)]["file_name"]

        # Get the unique channels from the first selected file
        # (.iloc is needed because the filtered Series keeps its original index labels)
        ds = xarray.open_dataset(studied_cycle_files.iloc[0])
        unique_channels = np.unique(ds["Channel_Index"].data).tolist()

        summary_results = []
        for this_channel in unique_channels:
            this_channel_values = np.empty(shape=(0,))
            for this_cycle_obs_file in studied_cycle_files:
                ds = xarray.open_dataset(this_cycle_obs_file)
                combined_bool = ds["Channel_Index"].data == this_channel

                if QC_filter:
                    combined_bool = combined_bool & (ds["QC_Flag"].data == 0)

                this_cycle_var_values = ds[var_name].data[combined_bool]
                this_channel_values = np.append(this_channel_values, this_cycle_var_values)

            summary_results.append([this_channel, np.size(this_channel_values),
                                    np.std(this_channel_values), np.mean(this_channel_values)])

        summary_results_df = pd.DataFrame(summary_results,
                                          columns=["channel", "count", "std", "mean"])

        # Plot per-channel observation counts
        plt.figure(figsize=(10, 6))
        plt.scatter(summary_results_df["channel"], summary_results_df["count"], s=50)
        plt.xlabel("Channel")
        plt.ylabel("Count")
        plt.title("%s %s" % (self.sensor, var_name))
        plt.xticks(summary_results_df["channel"])
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(output_path, "%s_%s_summaryCounts.png" % (self.sensor, var_name)))
        plt.close()

        # Plot per-channel mean and standard deviation
        plt.figure(figsize=(10, 6))
        plt.scatter(summary_results_df["channel"], summary_results_df["mean"], s=50, c='red', label="Mean")
        plt.scatter(summary_results_df["channel"], summary_results_df["std"], s=50, c='green', label="Std")
        plt.xlabel("Channel")
        plt.ylabel("Statistics")
        plt.title("%s %s" % (self.sensor, var_name))
        plt.xticks(summary_results_df["channel"])
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(output_path, "%s_%s_mean_std.png" % (self.sensor, var_name)))
        plt.close()

        return summary_results_df
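
The gridding in read_obs_values rests on a standard geopandas pattern: a spatial join that attaches each observation point to the grid cell containing it, followed by a per-cell groupby whose results are mapped back through grid_id. Below is a minimal, self-contained sketch of that pattern; the two 2-degree cells and three point values are made up for illustration, and the predicate keyword assumes geopandas 0.10 or newer:

import geopandas as gpd
from shapely.geometry import Point, Polygon

# Two 2x2-degree cells and three toy observation points (illustrative values only)
grid = gpd.GeoDataFrame(
    {"grid_id": [1, 2]},
    geometry=[Polygon([(0, 0), (2, 0), (2, 2), (0, 2)]),
              Polygon([(2, 0), (4, 0), (4, 2), (2, 2)])],
    crs="EPSG:4326")
obs = gpd.GeoDataFrame(
    {"value": [1.0, 3.0, 5.0]},
    geometry=[Point(0.5, 0.5), Point(1.5, 1.5), Point(3.0, 1.0)],
    crs="EPSG:4326")

# The right join keeps every grid cell, attaching the points that fall inside it
joined = gpd.sjoin(obs, grid, predicate="within", how="right")

# Per-cell statistics, aligned back to the grid through grid_id
stats = joined.groupby("grid_id")["value"]
grid["value_Average"] = grid["grid_id"].map(stats.mean())  # cell 1 -> 2.0, cell 2 -> 5.0
grid["value_Count"] = grid["grid_id"].map(stats.count())   # cell 1 -> 2,   cell 2 -> 1
print(grid[["grid_id", "value_Average", "value_Count"]])

Mapping through grid_id rather than assigning the grouped Series directly avoids index misalignment, since the grid's own index is 0-based while grid_id starts at 1.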
Driver script (new file, 47 lines):
from SpatialTemporalStats import SpatialTemporalStats

# Set input and output paths
input_path = r"/scratch2/NCEPDEV/stmp1/Emily.Liu/GDAS-ops-radstat/data/"
output_path = r'./Results'

# Set sensor name
sensor = "atms_n20"

# Set variable name and channel number
var_name = "Obs_Minus_Forecast_adjusted"
channel_no = 1

# Set start and end dates
start_date, end_date = '2023-03-01', '2023-03-10'

# Set region (1: global, 2: northern polar, 3: northern mid-latitudes,
# 4: tropics, 5: southern mid-latitudes, 6: southern polar)
region = 1

# Initialize the SpatialTemporalStats object
my_tool = SpatialTemporalStats()

# Set the grid resolution (in degrees) and generate the grid
resolution = 2
my_tool.generate_grid(resolution)

# Set the QC filter (must be either True or False)
QC_filter = True

# Set the per-variable filters: one (name, operation, value) tuple per filter,
# where the operation is 'lt' (less than) or 'gt' (greater than); observations
# matching a filter are excluded. The list may be empty.
filter_by_vars = [("Land_Fraction", 'lt', 0.9), ]
# filter_by_vars = []

# Read observation values and aggregate them onto the grid
o_minus_f_gdf = my_tool.read_obs_values(input_path, sensor, var_name, channel_no,
                                        start_date, end_date, filter_by_vars, QC_filter)

# Plot the gridded observation statistics
my_tool.plot_obs(o_minus_f_gdf, var_name, region, resolution, output_path)

# Make per-channel summary plots
summary_results = my_tool.make_summary_plots(input_path, sensor, var_name,
                                             start_date, end_date, QC_filter, output_path)

# Print summary results
print(summary_results)
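
If you are unsure which variable names a diag file provides for var_name or filter_by_vars, the list_variable_names helper prints a file's contents. A short sketch; the specific file name below is hypothetical, built from the 'diag_<sensor>_ges.<YYYYMMDDHH>.nc4' pattern the tool expects:

# List the variables in one diag file (hypothetical file name, for illustration)
my_tool.list_variable_names(input_path + "diag_atms_n20_ges.2023030100.nc4")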